mirror of
https://github.com/openbao/openbao.git
synced 2026-06-01 18:57:37 +02:00
Apply additional "go fix" analyzers (#2819)
Signed-off-by: Jonas Köhnen <jonas.koehnen@sap.com>
This commit is contained in:
@@ -128,16 +128,14 @@ func TestAppRole_TidyDanglingAccessors_RaceTest(t *testing.T) {
|
||||
true,
|
||||
)
|
||||
}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
wg.Go(func() {
|
||||
roleSecretIDReq := &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "role/role1/secret-id",
|
||||
Storage: storage,
|
||||
}
|
||||
_ = b.requestNoErr(t, roleSecretIDReq)
|
||||
}()
|
||||
})
|
||||
|
||||
entry, err := logical.StorageEntryJSON(
|
||||
fmt.Sprintf("accessor/invalid%d", count),
|
||||
|
||||
@@ -214,7 +214,7 @@ func (b *jwtAuthBackend) runCelProgram(ctx context.Context, operation logical.Op
|
||||
}
|
||||
|
||||
// handle protobuf Auth return type
|
||||
if msg, err := result.ConvertToNative(reflect.TypeOf(&pb.Auth{})); err == nil {
|
||||
if msg, err := result.ConvertToNative(reflect.TypeFor[*pb.Auth]()); err == nil {
|
||||
pbAuth, ok := msg.(*pb.Auth)
|
||||
if ok {
|
||||
return pbAuth, nil
|
||||
|
||||
@@ -189,7 +189,7 @@ func TestGSuiteProvider_FetchUserInfo(t *testing.T) {
|
||||
|
||||
// Assert that expected user info is added to the JWT claims
|
||||
customSchemas := tt.args.config.ProviderConfig["user_custom_schemas"].(string)
|
||||
for _, schema := range strings.Split(customSchemas, ",") {
|
||||
for schema := range strings.SplitSeq(customSchemas, ",") {
|
||||
assert.Equal(t, tt.expected[schema], allClaims[schema])
|
||||
}
|
||||
})
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
maps0 "maps"
|
||||
"maps"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -261,10 +261,10 @@ func makeRoleType(roleType string) string {
|
||||
}
|
||||
}
|
||||
|
||||
func combineMaps(maps ...map[string]string) map[string]string {
|
||||
func combineMaps(combine ...map[string]string) map[string]string {
|
||||
newMap := make(map[string]string)
|
||||
for _, m := range maps {
|
||||
maps0.Copy(newMap, m)
|
||||
for _, m := range combine {
|
||||
maps.Copy(newMap, m)
|
||||
}
|
||||
return newMap
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ package integrationtest
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
maps0 "maps"
|
||||
"maps"
|
||||
"math/rand"
|
||||
"os"
|
||||
"strings"
|
||||
@@ -403,10 +403,10 @@ func testK8sTokenAudiences(t *testing.T, expectedAudiences []interface{}, token
|
||||
assert.ElementsMatch(t, expectedAudiences, aud)
|
||||
}
|
||||
|
||||
func combineMaps(maps ...map[string]string) map[string]string {
|
||||
func combineMaps(combine ...map[string]string) map[string]string {
|
||||
newMap := make(map[string]string)
|
||||
for _, m := range maps {
|
||||
maps0.Copy(newMap, m)
|
||||
for _, m := range combine {
|
||||
maps.Copy(newMap, m)
|
||||
}
|
||||
return newMap
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ func newFieldRegistry() *fieldRegistry {
|
||||
|
||||
registryFields := vOfReg.Elem()
|
||||
for i := 0; i < registryFields.NumField(); i++ {
|
||||
if registryFields.Field(i).Kind() == reflect.Ptr {
|
||||
if registryFields.Field(i).Kind() == reflect.Pointer {
|
||||
|
||||
field := registryFields.Type().Field(i)
|
||||
ldapString := field.Tag.Get("ldap")
|
||||
|
||||
@@ -34,10 +34,10 @@ type ChallengeValidation struct {
|
||||
Thumbprint string `json:"thumbprint"`
|
||||
|
||||
Initiated time.Time `json:"initiated"`
|
||||
FirstValidation time.Time `json:"first_validation,omitempty"`
|
||||
FirstValidation time.Time `json:"first_validation"`
|
||||
RetryCount int `json:"retry_count,omitempty"`
|
||||
LastRetry time.Time `json:"last_retry,omitempty"`
|
||||
RetryAfter time.Time `json:"retry_after,omitempty"`
|
||||
LastRetry time.Time `json:"last_retry"`
|
||||
RetryAfter time.Time `json:"retry_after"`
|
||||
}
|
||||
|
||||
type ChallengeQueueEntry struct {
|
||||
|
||||
@@ -140,8 +140,7 @@ func (b *backend) acmeParsedWrapper(op acmeParsedOperation) framework.OperationF
|
||||
resp.Headers["Link"] = genAcmeLinkHeader(acmeCtx)
|
||||
} else {
|
||||
directory := genAcmeLinkHeader(acmeCtx)[0]
|
||||
addDirectory := !slices.Contains(resp.Headers["Link"], directory)
|
||||
if addDirectory {
|
||||
if !slices.Contains(resp.Headers["Link"], directory) {
|
||||
resp.Headers["Link"] = append(resp.Headers["Link"], directory)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -310,19 +310,11 @@ func validateCommonName(b *backend, data *inputBundle, name string) string {
|
||||
// If there's an at in the data, ensure email type validation is allowed.
|
||||
// Otherwise, ensure hostname is allowed.
|
||||
if strings.Contains(name, "@") {
|
||||
var allowsEmails bool
|
||||
if slices.Contains(data.role.CNValidations, "email") {
|
||||
allowsEmails = true
|
||||
}
|
||||
if !allowsEmails {
|
||||
if !slices.Contains(data.role.CNValidations, "email") {
|
||||
return name
|
||||
}
|
||||
} else {
|
||||
var allowsHostnames bool
|
||||
if slices.Contains(data.role.CNValidations, "hostname") {
|
||||
allowsHostnames = true
|
||||
}
|
||||
if !allowsHostnames {
|
||||
if !slices.Contains(data.role.CNValidations, "hostname") {
|
||||
return name
|
||||
}
|
||||
}
|
||||
|
||||
@@ -229,7 +229,7 @@ type CBValidateChain struct {
|
||||
func (c CBValidateChain) ChainToPEMs(t testing.TB, parent string, chain []string, knownCerts map[string]string) []string {
|
||||
var result []string
|
||||
for entryIndex, entry := range chain {
|
||||
var chainEntry string
|
||||
var chainEntry strings.Builder
|
||||
modifiedEntry := entry
|
||||
if entryIndex == 0 && entry == "self" {
|
||||
modifiedEntry = parent
|
||||
@@ -241,9 +241,9 @@ func (c CBValidateChain) ChainToPEMs(t testing.TB, parent string, chain []string
|
||||
cert, ok := knownCerts[issuer]
|
||||
require.Truef(t, ok, "Unknown issuer %v in chain for %v: %v", issuer, parent, chain)
|
||||
|
||||
chainEntry += cert
|
||||
chainEntry.WriteString(cert)
|
||||
}
|
||||
result = append(result, chainEntry)
|
||||
result = append(result, chainEntry.String())
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/openbao/openbao/sdk/v2/helper/errutil"
|
||||
)
|
||||
@@ -68,9 +69,7 @@ func (sc *storageContext) rebuildIssuersChains(referenceCert *issuerEntry /* opt
|
||||
// Our provided reference cert might not be in the list of issuers. In
|
||||
// that case, add it manually.
|
||||
if referenceCert != nil {
|
||||
missing := !slices.Contains(issuers, referenceCert.ID)
|
||||
|
||||
if missing {
|
||||
if !slices.Contains(issuers, referenceCert.ID) {
|
||||
issuers = append(issuers, referenceCert.ID)
|
||||
}
|
||||
}
|
||||
@@ -425,15 +424,15 @@ func (sc *storageContext) rebuildIssuersChains(referenceCert *issuerEntry /* opt
|
||||
// Assumption: no nodes left unprocessed. They should've either been
|
||||
// reached through the parent->child addition or they should've been
|
||||
// self-loops.
|
||||
var msg string
|
||||
var msg strings.Builder
|
||||
for _, issuer := range issuers {
|
||||
if visited, ok := processedIssuers[issuer]; !ok || !visited {
|
||||
pretty := prettyIssuer(issuerIdEntryMap, issuer)
|
||||
msg += fmt.Sprintf("[failed to build chain correctly: unprocessed issuer %v: ok: %v; visited: %v]\n", pretty, ok, visited)
|
||||
msg.WriteString(fmt.Sprintf("[failed to build chain correctly: unprocessed issuer %v: ok: %v; visited: %v]\n", pretty, ok, visited))
|
||||
}
|
||||
}
|
||||
if len(msg) > 0 {
|
||||
return errors.New(msg)
|
||||
if len(msg.String()) > 0 {
|
||||
return errors.New(msg.String())
|
||||
}
|
||||
|
||||
// Finally, write all issuers to disk.
|
||||
@@ -653,9 +652,7 @@ func processAnyCliqueOrCycle(
|
||||
// the nodes of whatever grouping
|
||||
foundNode := false
|
||||
for _, clique := range cliques {
|
||||
inClique := slices.Contains(clique, node)
|
||||
|
||||
if inClique {
|
||||
if slices.Contains(clique, node) {
|
||||
foundNode = true
|
||||
|
||||
// Compute this node's CAChain. Note order doesn't matter
|
||||
@@ -946,14 +943,6 @@ func isOnReissuedClique(
|
||||
return clique, nil
|
||||
}
|
||||
|
||||
func containsIssuer(collection []issuerID, target issuerID) bool {
|
||||
if len(collection) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
return slices.Contains(collection, target)
|
||||
}
|
||||
|
||||
func appendCycleIfNotExisting(knownCycles [][]issuerID, candidate []issuerID) [][]issuerID {
|
||||
// There's two ways to do cycle detection: canonicalize the cycles,
|
||||
// rewriting them to have the least (or max) element first or just
|
||||
@@ -1034,7 +1023,7 @@ func findCyclesNearClique(
|
||||
// We know the node has at least one child, since the clique is non-empty.
|
||||
for _, child := range issuerIdChildrenMap[cliqueNode] {
|
||||
// Skip children that are part of the clique.
|
||||
if containsIssuer(excludeNodes, child) {
|
||||
if slices.Contains(excludeNodes, child) {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -1130,9 +1119,7 @@ func findAllCyclesWithNode(
|
||||
continue
|
||||
}
|
||||
|
||||
skipNode := slices.Contains(exclude, child)
|
||||
|
||||
if skipNode {
|
||||
if slices.Contains(exclude, child) {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -1164,8 +1151,7 @@ func findAllCyclesWithNode(
|
||||
// We only care about source->source cycles. If this
|
||||
// cycles, but isn't a source->source cycle, don't add
|
||||
// this path.
|
||||
foundSelf := slices.Contains(path, child)
|
||||
if foundSelf {
|
||||
if slices.Contains(path, child) {
|
||||
// Skip this path.
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -130,14 +130,15 @@ func (ts *TestServer) buildNamedConf() string {
|
||||
forwarders += "\t};\n"
|
||||
}
|
||||
|
||||
zones := "\n"
|
||||
var zones strings.Builder
|
||||
zones.WriteString("\n")
|
||||
for _, domain := range ts.domains {
|
||||
zones += fmt.Sprintf("zone \"%s\" {\n", domain)
|
||||
zones += "\ttype primary;\n"
|
||||
zones += fmt.Sprintf("\tfile \"%s.zone\";\n", domain)
|
||||
zones += "\tallow-update {\n\t\tnone;\n\t};\n"
|
||||
zones += "\tnotify no;\n"
|
||||
zones += "};\n\n"
|
||||
zones.WriteString(fmt.Sprintf("zone \"%s\" {\n", domain))
|
||||
zones.WriteString("\ttype primary;\n")
|
||||
zones.WriteString(fmt.Sprintf("\tfile \"%s.zone\";\n", domain))
|
||||
zones.WriteString("\tallow-update {\n\t\tnone;\n\t};\n")
|
||||
zones.WriteString("\tnotify no;\n")
|
||||
zones.WriteString("};\n\n")
|
||||
}
|
||||
|
||||
// Reverse lookups are not handles as they're not presently necessary.
|
||||
@@ -150,20 +151,21 @@ func (ts *TestServer) buildNamedConf() string {
|
||||
` + forwarders + `
|
||||
};
|
||||
|
||||
` + zones
|
||||
` + zones.String()
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
func (ts *TestServer) buildZoneFile(target string) string {
|
||||
// One second TTL by default to allow quick refreshes.
|
||||
zone := "$TTL 1;\n"
|
||||
var zone strings.Builder
|
||||
zone.WriteString("$TTL 1;\n")
|
||||
|
||||
ts.serial += 1
|
||||
zone += fmt.Sprintf("@\tIN\tSOA\tns.%v.\troot.%v.\t(\n", target, target)
|
||||
zone += fmt.Sprintf("\t\t\t%d;\n\t\t\t1;\n\t\t\t1;\n\t\t\t2;\n\t\t\t1;\n\t\t\t)\n\n", ts.serial)
|
||||
zone += fmt.Sprintf("@\tIN\tNS\tns%d.%v.\n", ts.serial, target)
|
||||
zone += fmt.Sprintf("ns%d.%v.\tIN\tA\t%v\n", ts.serial, target, "127.0.0.1")
|
||||
zone.WriteString(fmt.Sprintf("@\tIN\tSOA\tns.%v.\troot.%v.\t(\n", target, target))
|
||||
zone.WriteString(fmt.Sprintf("\t\t\t%d;\n\t\t\t1;\n\t\t\t1;\n\t\t\t2;\n\t\t\t1;\n\t\t\t)\n\n", ts.serial))
|
||||
zone.WriteString(fmt.Sprintf("@\tIN\tNS\tns%d.%v.\n", ts.serial, target))
|
||||
zone.WriteString(fmt.Sprintf("ns%d.%v.\tIN\tA\t%v\n", ts.serial, target, "127.0.0.1"))
|
||||
|
||||
for domain, records := range ts.records {
|
||||
if !strings.HasSuffix(domain, target) {
|
||||
@@ -172,12 +174,12 @@ func (ts *TestServer) buildZoneFile(target string) string {
|
||||
|
||||
for recordType, values := range records {
|
||||
for _, value := range values {
|
||||
zone += fmt.Sprintf("%s.\tIN\t%s\t%s\n", domain, recordType, value)
|
||||
zone.WriteString(fmt.Sprintf("%s.\tIN\t%s\t%s\n", domain, recordType, value))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return zone
|
||||
return zone.String()
|
||||
}
|
||||
|
||||
func (ts *TestServer) pushNamedConf() {
|
||||
|
||||
@@ -996,8 +996,7 @@ func (b *backend) getRole(ctx context.Context, s logical.Storage, n string) (*ro
|
||||
modified = true
|
||||
}
|
||||
if result.AllowedBaseDomain != "" {
|
||||
found := slices.Contains(result.AllowedDomains, result.AllowedBaseDomain)
|
||||
if !found {
|
||||
if !slices.Contains(result.AllowedDomains, result.AllowedBaseDomain) {
|
||||
result.AllowedDomains = append(result.AllowedDomains, result.AllowedBaseDomain)
|
||||
}
|
||||
result.AllowedBaseDomain = ""
|
||||
|
||||
@@ -479,9 +479,10 @@ func RunNginxRootTest(t *testing.T, caKeyType string, caKeyBits int, caUsePSS bo
|
||||
requireSuccessNonNilResponse(t, resp, err, "failed to create server leaf cert")
|
||||
leafCert := resp.Data["certificate"].(string)
|
||||
leafPrivateKey := resp.Data["private_key"].(string) + "\n"
|
||||
fullChain := leafCert + "\n"
|
||||
var fullChain strings.Builder
|
||||
fullChain.WriteString(leafCert + "\n")
|
||||
for _, cert := range resp.Data["ca_chain"].([]string) {
|
||||
fullChain += cert + "\n"
|
||||
fullChain.WriteString(cert + "\n")
|
||||
}
|
||||
|
||||
// Issue a client leaf certificate.
|
||||
@@ -546,7 +547,7 @@ func RunNginxRootTest(t *testing.T, caKeyType string, caKeyBits int, caUsePSS bo
|
||||
|
||||
crls := rootCRL + intCRL + deltaCRL
|
||||
|
||||
cleanup, host, port, networkName, networkAddr, networkPort := buildNginxContainer(t, rootCert, crls, fullChain, leafPrivateKey)
|
||||
cleanup, host, port, networkName, networkAddr, networkPort := buildNginxContainer(t, rootCert, crls, fullChain.String(), leafPrivateKey)
|
||||
defer cleanup()
|
||||
|
||||
if host != "127.0.0.1" && host != "::1" && strings.HasPrefix(host, containerName) {
|
||||
|
||||
@@ -136,9 +136,7 @@ func RunZLintRootTest(t *testing.T, keyType string, keyBits int, usePSS bool, ig
|
||||
}
|
||||
|
||||
if result == "error" {
|
||||
skip := slices.Contains(ignored, key)
|
||||
|
||||
if !skip {
|
||||
if !slices.Contains(ignored, key) {
|
||||
t.Fatalf("got unexpected error from test %v: %v", key, value)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -281,8 +281,8 @@ func validateUsername(username, allowedUsers string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
userList := strings.Split(allowedUsers, ",")
|
||||
for _, user := range userList {
|
||||
userList := strings.SplitSeq(allowedUsers, ",")
|
||||
for user := range userList {
|
||||
if strings.TrimSpace(user) == username {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -66,7 +66,7 @@ func cidrListContainsIP(ip, cidrList string) (bool, error) {
|
||||
if len(cidrList) == 0 {
|
||||
return false, errors.New("IP does not belong to role")
|
||||
}
|
||||
for _, item := range strings.Split(cidrList, ",") {
|
||||
for item := range strings.SplitSeq(cidrList, ",") {
|
||||
_, cidrIPNet, err := net.ParseCIDR(item)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("invalid CIDR entry %q", item)
|
||||
|
||||
@@ -6,13 +6,12 @@ package transit
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/sha3"
|
||||
"crypto/sha512"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"hash"
|
||||
|
||||
"golang.org/x/crypto/sha3"
|
||||
|
||||
"github.com/openbao/openbao/sdk/v2/framework"
|
||||
"github.com/openbao/openbao/sdk/v2/logical"
|
||||
)
|
||||
|
||||
+24
-48
@@ -380,14 +380,12 @@ listener "tcp" {
|
||||
var output string
|
||||
var code int
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Go(func() {
|
||||
code = cmd.Run([]string{"-config", configPath})
|
||||
if code != 0 {
|
||||
output = ui.ErrorWriter.String() + ui.OutputWriter.String()
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
})
|
||||
|
||||
select {
|
||||
case <-cmd.startedCh:
|
||||
@@ -588,16 +586,14 @@ auto_auth {
|
||||
cmd.startedCh = make(chan struct{})
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Go(func() {
|
||||
code := cmd.Run([]string{"-config", configPath})
|
||||
if code != 0 {
|
||||
t.Errorf("non-zero return code when running agent: %d", code)
|
||||
t.Logf("STDOUT from agent:\n%s", ui.OutputWriter.String())
|
||||
t.Logf("STDERR from agent:\n%s", ui.ErrorWriter.String())
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
})
|
||||
|
||||
select {
|
||||
case <-cmd.startedCh:
|
||||
@@ -1055,16 +1051,14 @@ auto_auth {
|
||||
cmd.startedCh = make(chan struct{})
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Go(func() {
|
||||
code := cmd.Run([]string{"-config", configPath})
|
||||
if code != 0 {
|
||||
t.Errorf("non-zero return code when running agent: %d", code)
|
||||
t.Logf("STDOUT from agent:\n%s", ui.OutputWriter.String())
|
||||
t.Logf("STDERR from agent:\n%s", ui.ErrorWriter.String())
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
})
|
||||
|
||||
select {
|
||||
case <-cmd.startedCh:
|
||||
@@ -1230,16 +1224,14 @@ exit_after_auth = true
|
||||
cmd.startedCh = make(chan struct{})
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Go(func() {
|
||||
code := cmd.Run([]string{"-config", configPath})
|
||||
if code != 0 {
|
||||
t.Errorf("non-zero return code when running agent: %d", code)
|
||||
t.Logf("STDOUT from agent:\n%s", ui.OutputWriter.String())
|
||||
t.Logf("STDERR from agent:\n%s", ui.ErrorWriter.String())
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
})
|
||||
|
||||
select {
|
||||
case <-cmd.startedCh:
|
||||
@@ -1782,11 +1774,9 @@ api_proxy {
|
||||
cmd.startedCh = make(chan struct{})
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Go(func() {
|
||||
cmd.Run([]string{"-config", configPath})
|
||||
wg.Done()
|
||||
}()
|
||||
})
|
||||
|
||||
select {
|
||||
case <-cmd.startedCh:
|
||||
@@ -1873,11 +1863,9 @@ vault {
|
||||
cmd.startedCh = make(chan struct{})
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Go(func() {
|
||||
cmd.Run([]string{"-config", configPath})
|
||||
wg.Done()
|
||||
}()
|
||||
})
|
||||
|
||||
select {
|
||||
case <-cmd.startedCh:
|
||||
@@ -1965,11 +1953,9 @@ vault {
|
||||
cmd.startedCh = make(chan struct{})
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Go(func() {
|
||||
cmd.Run([]string{"-config", configPath})
|
||||
wg.Done()
|
||||
}()
|
||||
})
|
||||
|
||||
select {
|
||||
case <-cmd.startedCh:
|
||||
@@ -2041,11 +2027,9 @@ vault {
|
||||
cmd.startedCh = make(chan struct{})
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Go(func() {
|
||||
cmd.Run([]string{"-config", configPath})
|
||||
wg.Done()
|
||||
}()
|
||||
})
|
||||
|
||||
select {
|
||||
case <-cmd.startedCh:
|
||||
@@ -2213,11 +2197,9 @@ vault {
|
||||
cmd.startedCh = make(chan struct{})
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Go(func() {
|
||||
cmd.Run([]string{"-config", configPath})
|
||||
wg.Done()
|
||||
}()
|
||||
})
|
||||
|
||||
select {
|
||||
case <-cmd.startedCh:
|
||||
@@ -2596,14 +2578,12 @@ listener "tcp" {
|
||||
var output string
|
||||
var code int
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Go(func() {
|
||||
code = cmd.Run([]string{"-config", configPath})
|
||||
if code != 0 {
|
||||
output = ui.ErrorWriter.String() + ui.OutputWriter.String()
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
})
|
||||
|
||||
select {
|
||||
case <-cmd.startedCh:
|
||||
@@ -2706,11 +2686,9 @@ cache {}
|
||||
cmd.startedCh = make(chan struct{})
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Go(func() {
|
||||
cmd.Run([]string{"-config", configPath})
|
||||
wg.Done()
|
||||
}()
|
||||
})
|
||||
|
||||
select {
|
||||
case <-cmd.startedCh:
|
||||
@@ -3035,13 +3013,11 @@ vault {
|
||||
var output string
|
||||
var code int
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Go(func() {
|
||||
if code = cmd.Run([]string{"-config", configPath}); code != 0 {
|
||||
output = ui.ErrorWriter.String() + ui.OutputWriter.String()
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
})
|
||||
|
||||
select {
|
||||
case <-cmd.startedCh:
|
||||
|
||||
+4
-12
@@ -550,11 +550,7 @@ func TestLeaseCache_Concurrent_NonCacheable(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
// 100 concurrent requests
|
||||
for range 100 {
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
wg.Go(func() {
|
||||
// Send a request through the lease cache which is not cacheable (there is
|
||||
// no lease information or auth information in the response)
|
||||
sendReq := &SendRequest{
|
||||
@@ -565,7 +561,7 @@ func TestLeaseCache_Concurrent_NonCacheable(t *testing.T) {
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
@@ -602,11 +598,7 @@ func TestLeaseCache_Concurrent_Cacheable(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
// Start 100 concurrent requests
|
||||
for range 100 {
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
wg.Go(func() {
|
||||
sendReq := &SendRequest{
|
||||
Token: "autoauthtoken",
|
||||
Request: httptest.NewRequest("GET", "http://example.com/v1/sample/api", nil),
|
||||
@@ -620,7 +612,7 @@ func TestLeaseCache_Concurrent_Cacheable(t *testing.T) {
|
||||
if resp.CacheMeta != nil && resp.CacheMeta.Hit {
|
||||
cacheCount.Add(1)
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
@@ -69,9 +69,7 @@ func TestAppRole_Integ_ConcurrentLogins(t *testing.T) {
|
||||
wg := &sync.WaitGroup{}
|
||||
|
||||
for range 100 {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
wg.Go(func() {
|
||||
appRoleAuth, err := auth.NewAppRoleAuth(roleID, &auth.SecretID{FromString: secretID})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
@@ -86,8 +84,7 @@ func TestAppRole_Integ_ConcurrentLogins(t *testing.T) {
|
||||
t.Error("expected a successful login")
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
+6
-12
@@ -767,9 +767,7 @@ func (c *DebugCommand) collectPprof(ctx context.Context) {
|
||||
|
||||
// As a convenience, we'll also fetch the goroutine target using debug=2, which yields a text
|
||||
// version of the stack traces that don't require using `go tool pprof` to view.
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
wg.Go(func() {
|
||||
data, err := pprofTarget(ctx, c.cachedClient, "goroutine", url.Values{"debug": []string{"2"}})
|
||||
if err != nil {
|
||||
c.captureError("pprof.goroutines-text", err)
|
||||
@@ -780,7 +778,7 @@ func (c *DebugCommand) collectPprof(ctx context.Context) {
|
||||
if err != nil {
|
||||
c.captureError("pprof.goroutines-text", err)
|
||||
}
|
||||
}()
|
||||
})
|
||||
|
||||
// If the our remaining duration is less than the interval value
|
||||
// skip profile and trace.
|
||||
@@ -791,9 +789,7 @@ func (c *DebugCommand) collectPprof(ctx context.Context) {
|
||||
}
|
||||
|
||||
// Capture profile
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
wg.Go(func() {
|
||||
data, err := pprofProfile(ctx, c.cachedClient, c.flagInterval)
|
||||
if err != nil {
|
||||
c.captureError("pprof.profile", err)
|
||||
@@ -804,12 +800,10 @@ func (c *DebugCommand) collectPprof(ctx context.Context) {
|
||||
if err != nil {
|
||||
c.captureError("pprof.profile", err)
|
||||
}
|
||||
}()
|
||||
})
|
||||
|
||||
// Capture trace
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
wg.Go(func() {
|
||||
data, err := pprofTrace(ctx, c.cachedClient, c.flagInterval)
|
||||
if err != nil {
|
||||
c.captureError("pprof.trace", err)
|
||||
@@ -820,7 +814,7 @@ func (c *DebugCommand) collectPprof(ctx context.Context) {
|
||||
if err != nil {
|
||||
c.captureError("pprof.trace", err)
|
||||
}
|
||||
}()
|
||||
})
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
@@ -156,9 +156,7 @@ func (h *AuditVisibility) Evaluate(e *Executor) (results []*Result, err error) {
|
||||
}
|
||||
|
||||
for _, param := range visibleList {
|
||||
found := slices.Contains(actual, param)
|
||||
|
||||
if !found {
|
||||
if !slices.Contains(actual, param) {
|
||||
ret := Result{
|
||||
Status: ResultInformational,
|
||||
Endpoint: "/sys/mounts/{{mount}}/tune",
|
||||
@@ -179,9 +177,7 @@ func (h *AuditVisibility) Evaluate(e *Executor) (results []*Result, err error) {
|
||||
return nil, fmt.Errorf("error parsing %v from server: %v", source, err)
|
||||
}
|
||||
for _, param := range hiddenList {
|
||||
found := slices.Contains(actual, param)
|
||||
|
||||
if found {
|
||||
if slices.Contains(actual, param) {
|
||||
ret := Result{
|
||||
Status: ResultWarning,
|
||||
Endpoint: "/sys/mounts/{{mount}}/tune",
|
||||
|
||||
@@ -598,7 +598,7 @@ func TestOperatorRekeyCommand_Run(t *testing.T) {
|
||||
}
|
||||
nonce := status.Nonce
|
||||
|
||||
var combined string
|
||||
var combined strings.Builder
|
||||
// Supply the unseal keys
|
||||
for _, key := range keys {
|
||||
ui, cmd := testOperatorRekeyCommand(t)
|
||||
@@ -613,11 +613,11 @@ func TestOperatorRekeyCommand_Run(t *testing.T) {
|
||||
}
|
||||
|
||||
// Append to our output string
|
||||
combined += ui.OutputWriter.String()
|
||||
combined.WriteString(ui.OutputWriter.String())
|
||||
}
|
||||
|
||||
re := regexp.MustCompile(`Key 1 fingerprint: (.+); value: (.+)`)
|
||||
match := re.FindAllStringSubmatch(combined, -1)
|
||||
match := re.FindAllStringSubmatch(combined.String(), -1)
|
||||
if len(match) < 1 || len(match[0]) < 3 {
|
||||
t.Fatalf("bad match: %#v", match)
|
||||
}
|
||||
|
||||
@@ -496,7 +496,7 @@ func TestOperatorRotateKeysCommand_Run(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
nonce := status.Nonce
|
||||
|
||||
var combined string
|
||||
var combined strings.Builder
|
||||
// Supply the unseal keys
|
||||
for _, key := range keys {
|
||||
ui, cmd := testOperatorRotateKeysCommand(t)
|
||||
@@ -509,11 +509,11 @@ func TestOperatorRotateKeysCommand_Run(t *testing.T) {
|
||||
require.Equalf(t, 0, code, "expected %d to be %d: %s", code, 0, ui.ErrorWriter.String())
|
||||
|
||||
// Append to our output string
|
||||
combined += ui.OutputWriter.String()
|
||||
combined.WriteString(ui.OutputWriter.String())
|
||||
}
|
||||
|
||||
re := regexp.MustCompile(`Key 1 fingerprint: (.+); value: (.+)`)
|
||||
match := re.FindAllStringSubmatch(combined, -1)
|
||||
match := re.FindAllStringSubmatch(combined.String(), -1)
|
||||
require.False(t, len(match) < 1 || len(match[0]) < 3)
|
||||
|
||||
// Grab the output fingerprint and encrypted key
|
||||
|
||||
@@ -173,11 +173,11 @@ func pkiIssue(c *BaseCommand, parentMountIssuer string, intermediateMount string
|
||||
failureState.certSerialNumber = serialNumber
|
||||
|
||||
caChain := rootResp.Data["ca_chain"].([]interface{})
|
||||
caChainPemBundle := ""
|
||||
var caChainPemBundle strings.Builder
|
||||
for _, cert := range caChain {
|
||||
caChainPemBundle += cert.(string) + "\n"
|
||||
caChainPemBundle.WriteString(cert.(string) + "\n")
|
||||
}
|
||||
failureState.caChain = caChainPemBundle
|
||||
failureState.caChain = caChainPemBundle.String()
|
||||
|
||||
// Next Import Certificate
|
||||
certificate := rootResp.Data["certificate"].(string)
|
||||
@@ -214,7 +214,7 @@ func pkiIssue(c *BaseCommand, parentMountIssuer string, intermediateMount string
|
||||
// Finally Import CA_Chain (just in case there's more information)
|
||||
if len(caChain) > 2 { // We've already imported parent cert and newly issued cert above
|
||||
importData := map[string]interface{}{
|
||||
"pem_bundle": caChainPemBundle,
|
||||
"pem_bundle": caChainPemBundle.String(),
|
||||
}
|
||||
_, err := client.Logical().Write(intermediateMount+"/issuers/import/cert", importData)
|
||||
if err != nil {
|
||||
|
||||
+14
-28
@@ -363,11 +363,9 @@ api_proxy {
|
||||
cmd.startedCh = make(chan struct{})
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Go(func() {
|
||||
cmd.Run([]string{"-config", configPath})
|
||||
wg.Done()
|
||||
}()
|
||||
})
|
||||
|
||||
select {
|
||||
case <-cmd.startedCh:
|
||||
@@ -457,11 +455,9 @@ vault {
|
||||
cmd.startedCh = make(chan struct{})
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Go(func() {
|
||||
cmd.Run([]string{"-config", configPath})
|
||||
wg.Done()
|
||||
}()
|
||||
})
|
||||
|
||||
select {
|
||||
case <-cmd.startedCh:
|
||||
@@ -549,11 +545,9 @@ vault {
|
||||
cmd.startedCh = make(chan struct{})
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Go(func() {
|
||||
cmd.Run([]string{"-config", configPath})
|
||||
wg.Done()
|
||||
}()
|
||||
})
|
||||
|
||||
select {
|
||||
case <-cmd.startedCh:
|
||||
@@ -627,11 +621,9 @@ vault {
|
||||
cmd.startedCh = make(chan struct{})
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Go(func() {
|
||||
cmd.Run([]string{"-config", configPath})
|
||||
wg.Done()
|
||||
}()
|
||||
})
|
||||
|
||||
select {
|
||||
case <-cmd.startedCh:
|
||||
@@ -799,11 +791,9 @@ vault {
|
||||
cmd.startedCh = make(chan struct{})
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Go(func() {
|
||||
cmd.Run([]string{"-config", configPath})
|
||||
wg.Done()
|
||||
}()
|
||||
})
|
||||
|
||||
select {
|
||||
case <-cmd.startedCh:
|
||||
@@ -876,16 +866,14 @@ listener "tcp" {
|
||||
cmd.startedCh = make(chan struct{})
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Go(func() {
|
||||
code := cmd.Run([]string{"-config", configPath})
|
||||
if code != 0 {
|
||||
t.Errorf("non-zero return code when running proxy: %d", code)
|
||||
t.Logf("STDOUT from proxy:\n%s", ui.OutputWriter.String())
|
||||
t.Logf("STDERR from proxy:\n%s", ui.ErrorWriter.String())
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
})
|
||||
|
||||
select {
|
||||
case <-cmd.startedCh:
|
||||
@@ -982,11 +970,9 @@ cache {}
|
||||
cmd.startedCh = make(chan struct{})
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Go(func() {
|
||||
cmd.Run([]string{"-config", configPath})
|
||||
wg.Done()
|
||||
}()
|
||||
})
|
||||
|
||||
select {
|
||||
case <-cmd.startedCh:
|
||||
|
||||
@@ -226,8 +226,8 @@ func TestSecretsEnableCommand_Run(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
modLines := strings.Split(string(modFile), "\n")
|
||||
for _, p := range modLines {
|
||||
modLines := strings.SplitSeq(string(modFile), "\n")
|
||||
for p := range modLines {
|
||||
splitLine := strings.Split(strings.TrimSpace(p), " ")
|
||||
if len(splitLine) == 0 {
|
||||
continue
|
||||
|
||||
@@ -71,8 +71,7 @@ func GenerateCert(caCertTemplate *x509.Certificate, caSigner crypto.Signer) (str
|
||||
}
|
||||
|
||||
// Only add our hostname to SANs if it isn't found.
|
||||
foundHostname := slices.Contains(template.DNSNames, hostname)
|
||||
if !foundHostname {
|
||||
if !slices.Contains(template.DNSNames, hostname) {
|
||||
template.DNSNames = append(template.DNSNames, hostname)
|
||||
}
|
||||
|
||||
|
||||
@@ -7,16 +7,11 @@ import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
"regexp"
|
||||
slices0 "slices"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
credUserpass "github.com/openbao/openbao/builtin/credential/userpass"
|
||||
dbMysql "github.com/openbao/openbao/plugins/database/mysql"
|
||||
"github.com/openbao/openbao/sdk/v2/helper/consts"
|
||||
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// Test_RegistryGet exercises the (registry).Get functionality by comparing
|
||||
@@ -26,28 +21,24 @@ func Test_RegistryGet(t *testing.T) {
|
||||
name string
|
||||
builtin string
|
||||
pluginType consts.PluginType
|
||||
want BuiltinFactory
|
||||
wantOk bool
|
||||
}{
|
||||
{
|
||||
name: "non-existent builtin",
|
||||
builtin: "foo",
|
||||
pluginType: consts.PluginTypeCredential,
|
||||
want: nil,
|
||||
wantOk: false,
|
||||
},
|
||||
{
|
||||
name: "bad plugin type",
|
||||
builtin: "app-id",
|
||||
pluginType: 9000,
|
||||
want: nil,
|
||||
wantOk: false,
|
||||
},
|
||||
{
|
||||
name: "known builtin lookup",
|
||||
builtin: "userpass",
|
||||
pluginType: consts.PluginTypeCredential,
|
||||
want: toFunc(credUserpass.Factory),
|
||||
wantOk: true,
|
||||
},
|
||||
// The app-id plugin has been fully removed from OpenBao.
|
||||
@@ -55,27 +46,18 @@ func Test_RegistryGet(t *testing.T) {
|
||||
name: "removed builtin lookup",
|
||||
builtin: "app-id",
|
||||
pluginType: consts.PluginTypeCredential,
|
||||
want: nil,
|
||||
wantOk: true,
|
||||
},*/
|
||||
{
|
||||
name: "known builtin lookup",
|
||||
builtin: "mysql-database-plugin",
|
||||
pluginType: consts.PluginTypeDatabase,
|
||||
want: dbMysql.New(dbMysql.DefaultUserNameTemplate),
|
||||
wantOk: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var got BuiltinFactory
|
||||
got, ok := Registry.Get(tt.builtin, tt.pluginType)
|
||||
if ok {
|
||||
if reflect.TypeOf(got) != reflect.TypeOf(tt.want) {
|
||||
t.Fatalf("got type: %T, want type: %T", got, tt.want)
|
||||
}
|
||||
}
|
||||
if tt.wantOk != ok {
|
||||
if _, ok := Registry.Get(tt.builtin, tt.pluginType); tt.wantOk != ok {
|
||||
t.Fatalf("error: got %v, want %v", ok, tt.wantOk)
|
||||
}
|
||||
})
|
||||
@@ -296,7 +278,7 @@ func Test_RegistryMatchesGenOpenapi(t *testing.T) {
|
||||
ensureInScript := func(t *testing.T, scriptBackends []string, name string) {
|
||||
t.Helper()
|
||||
|
||||
if slices0.Contains([]string{
|
||||
if slices.Contains([]string{
|
||||
"oidc",
|
||||
"openldap",
|
||||
}, name) {
|
||||
|
||||
@@ -74,7 +74,7 @@ func (c CharsetRule) Pass(value []rune) bool {
|
||||
// charIn is sometimes faster than a map lookup because the data is so small
|
||||
// This is being kept rather than converted to a map to keep the code cleaner,
|
||||
// otherwise there would need to be additional parsing logic.
|
||||
if charIn(r, c.Charset) {
|
||||
if slices.Contains(c.Charset, r) {
|
||||
count++
|
||||
if count >= c.MinChars {
|
||||
return true
|
||||
@@ -84,7 +84,3 @@ func (c CharsetRule) Pass(value []rune) bool {
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func charIn(search rune, charset []rune) bool {
|
||||
return slices.Contains(charset, search)
|
||||
}
|
||||
|
||||
+2
-2
@@ -533,8 +533,8 @@ func WrapForwardedForHandler(h http.Handler, l *configutil.Listener) http.Handle
|
||||
// to the multiple-header case.
|
||||
var acc []string
|
||||
for _, header := range headers {
|
||||
vals := strings.Split(header, ",")
|
||||
for _, v := range vals {
|
||||
vals := strings.SplitSeq(header, ",")
|
||||
for v := range vals {
|
||||
acc = append(acc, strings.TrimSpace(v))
|
||||
}
|
||||
}
|
||||
|
||||
+4
-3
@@ -10,6 +10,7 @@ import (
|
||||
"math"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -136,11 +137,11 @@ func makeLongEmptyList(size int) interface{} {
|
||||
}
|
||||
|
||||
func makeLongString(size int) interface{} {
|
||||
var x string
|
||||
var x strings.Builder
|
||||
for i := range size {
|
||||
x += fmt.Sprintf("%d", i%10)
|
||||
x.WriteString(fmt.Sprintf("%d", i%10))
|
||||
}
|
||||
return x
|
||||
return x.String()
|
||||
}
|
||||
|
||||
func makeLargeMap(size int) interface{} {
|
||||
|
||||
@@ -332,16 +332,14 @@ func (l *raftLayer) Handoff(ctx context.Context, wg *sync.WaitGroup, quit chan s
|
||||
return errors.New("raft is shutdown")
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
wg.Go(func() {
|
||||
select {
|
||||
case l.connCh <- conn:
|
||||
case <-l.closeCh:
|
||||
case <-ctx.Done():
|
||||
case <-quit:
|
||||
}
|
||||
}()
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -447,7 +447,7 @@ func assertAllFieldsSetValue(name string, rVal reflect.Value) error {
|
||||
}
|
||||
|
||||
switch rVal.Kind() {
|
||||
case reflect.Ptr, reflect.Interface:
|
||||
case reflect.Pointer, reflect.Interface:
|
||||
return assertAllFieldsSetValue(name, rVal.Elem())
|
||||
case reflect.Struct:
|
||||
return assertAllFieldsSetStruct(name, rVal)
|
||||
@@ -483,7 +483,7 @@ func assertAllFieldsSetValue(name string, rVal reflect.Value) error {
|
||||
|
||||
func assertAllFieldsSetStruct(name string, rVal reflect.Value) error {
|
||||
switch rVal.Type() {
|
||||
case reflect.TypeOf(timestamppb.Timestamp{}):
|
||||
case reflect.TypeFor[timestamppb.Timestamp]():
|
||||
ts := rVal.Interface().(timestamppb.Timestamp)
|
||||
if ts.AsTime().IsZero() {
|
||||
return fmt.Errorf("%s is zero", name)
|
||||
|
||||
@@ -1041,7 +1041,7 @@ func hyphenatedToTitleCase(in string) string {
|
||||
|
||||
title := cases.Title(language.English, cases.NoLower)
|
||||
|
||||
for _, word := range strings.Split(in, "-") {
|
||||
for word := range strings.SplitSeq(in, "-") {
|
||||
b.WriteString(title.String(word))
|
||||
}
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ func (p *PolicyMap) Policies(ctx context.Context, s logical.Storage, names ...st
|
||||
continue
|
||||
}
|
||||
|
||||
for _, p := range strings.Split(values, ",") {
|
||||
for p := range strings.SplitSeq(values, ",") {
|
||||
if p = strings.TrimSpace(p); p != "" {
|
||||
set[p] = struct{}{}
|
||||
}
|
||||
|
||||
@@ -96,7 +96,7 @@ func AddIdentity(view logical.SystemView, req *logical.Request, data map[string]
|
||||
|
||||
func encodeJSON(value ref.Val) ref.Val {
|
||||
native, err := value.ConvertToNative(
|
||||
reflect.TypeOf(map[string]any{}),
|
||||
reflect.TypeFor[map[string]any](),
|
||||
)
|
||||
if err != nil {
|
||||
return types.Bool(false)
|
||||
|
||||
@@ -112,8 +112,8 @@ func ParseHexFormatted(in, sep string) []byte {
|
||||
var ret bytes.Buffer
|
||||
var err error
|
||||
var inBits uint64
|
||||
inBytes := strings.Split(in, sep)
|
||||
for _, inByte := range inBytes {
|
||||
inBytes := strings.SplitSeq(in, sep)
|
||||
for inByte := range inBytes {
|
||||
if inBits, err = strconv.ParseUint(inByte, 16, 8); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -33,9 +33,9 @@ type Client struct {
|
||||
func (c *Client) DialLDAP(cfg *ConfigEntry) (Connection, error) {
|
||||
var retErr *multierror.Error
|
||||
var conn Connection
|
||||
urls := strings.Split(cfg.Url, ",")
|
||||
urls := strings.SplitSeq(cfg.Url, ",")
|
||||
|
||||
for _, uut := range urls {
|
||||
for uut := range urls {
|
||||
u, err := url.Parse(uut)
|
||||
if err != nil {
|
||||
retErr = multierror.Append(retErr, fmt.Errorf("error parsing url %q: %w", uut, err))
|
||||
@@ -468,12 +468,13 @@ func sidBytesToString(b []byte) (string, error) {
|
||||
return "", fmt.Errorf("SID %#v convert failed reading SubAuthority: %w", b, err)
|
||||
}
|
||||
|
||||
result := fmt.Sprintf("S-%d-%d", revision, identifierAuthority)
|
||||
var result strings.Builder
|
||||
result.WriteString(fmt.Sprintf("S-%d-%d", revision, identifierAuthority))
|
||||
for _, subAuthorityPart := range subAuthority {
|
||||
result += fmt.Sprintf("-%d", subAuthorityPart)
|
||||
result.WriteString(fmt.Sprintf("-%d", subAuthorityPart))
|
||||
}
|
||||
|
||||
return result, nil
|
||||
return result.String(), nil
|
||||
}
|
||||
|
||||
func (c *Client) performLdapTokenGroupsSearch(cfg *ConfigEntry, conn Connection, userDN string) ([]*ldap.Entry, error) {
|
||||
@@ -505,10 +506,7 @@ func (c *Client) performLdapTokenGroupsSearch(cfg *ConfigEntry, conn Connection,
|
||||
groupEntries := make([]*ldap.Entry, 0, len(groupAttrValues))
|
||||
|
||||
for range maxWorkers {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
wg.Go(func() {
|
||||
for sid := range taskChan {
|
||||
groupResult, err := conn.Search(&ldap.SearchRequest{
|
||||
BaseDN: fmt.Sprintf("<SID=%s>", sid),
|
||||
@@ -534,7 +532,7 @@ func (c *Client) performLdapTokenGroupsSearch(cfg *ConfigEntry, conn Connection,
|
||||
groupEntries = append(groupEntries, groupResult.Entries[0])
|
||||
lock.Unlock()
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
for _, sidBytes := range groupAttrValues {
|
||||
|
||||
@@ -589,7 +589,7 @@ func (c *Client) VerifyPeerCertificate(ctx context.Context, verifiedChains [][]*
|
||||
}
|
||||
|
||||
func (c *Client) canEarlyExitForOCSP(results []*ocspStatus, chainSize int, conf *VerifyConfig) *ocspStatus {
|
||||
msg := ""
|
||||
var msg strings.Builder
|
||||
if conf.OcspFailureMode == FailOpenFalse {
|
||||
// Fail closed. any error is returned to stop connection
|
||||
for _, r := range results {
|
||||
@@ -611,13 +611,13 @@ func (c *Client) canEarlyExitForOCSP(results []*ocspStatus, chainSize int, conf
|
||||
return r
|
||||
}
|
||||
if r != nil && r.code != ocspStatusGood && r.err != nil {
|
||||
msg += "" + r.err.Error()
|
||||
msg.WriteString("" + r.err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(msg) > 0 {
|
||||
if len(msg.String()) > 0 {
|
||||
c.Logger().Warn(
|
||||
"OCSP is set to fail-open, and could not retrieve OCSP based revocation checking but proceeding.", "detail", msg)
|
||||
"OCSP is set to fail-open, and could not retrieve OCSP based revocation checking but proceeding.", "detail", msg.String())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -49,10 +49,8 @@ func (r *retryHandler) Run(shutdownCh <-chan struct{}, wait *sync.WaitGroup) {
|
||||
r.setInitialState(shutdownCh)
|
||||
|
||||
// Run this in a go func so this call doesn't block.
|
||||
wait.Add(1)
|
||||
go func() {
|
||||
wait.Go(func() {
|
||||
// Make sure Vault will give us time to finish up here.
|
||||
defer wait.Done()
|
||||
|
||||
var g run.Group
|
||||
|
||||
@@ -80,7 +78,7 @@ func (r *retryHandler) Run(shutdownCh <-chan struct{}, wait *sync.WaitGroup) {
|
||||
if err := g.Run(); err != nil {
|
||||
r.logger.Error("error encountered during periodic state update", "error", err)
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
func (r *retryHandler) setInitialState(shutdownCh <-chan struct{}) {
|
||||
|
||||
@@ -28,9 +28,7 @@ func TestInmemCluster_Connect(t *testing.T) {
|
||||
var accepted int
|
||||
stopCh := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
wg.Go(func() {
|
||||
for {
|
||||
select {
|
||||
case <-stopCh:
|
||||
@@ -48,7 +46,7 @@ func TestInmemCluster_Connect(t *testing.T) {
|
||||
accepted++
|
||||
|
||||
}
|
||||
}()
|
||||
})
|
||||
|
||||
// Make sure two nodes can connect in
|
||||
conn, err := cluster.layers[1].DialContext(t.Context(), server.addr, nil)
|
||||
@@ -94,9 +92,7 @@ func TestInmemCluster_Disconnect(t *testing.T) {
|
||||
var accepted int
|
||||
stopCh := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
wg.Go(func() {
|
||||
for {
|
||||
select {
|
||||
case <-stopCh:
|
||||
@@ -114,7 +110,7 @@ func TestInmemCluster_Disconnect(t *testing.T) {
|
||||
accepted++
|
||||
|
||||
}
|
||||
}()
|
||||
})
|
||||
|
||||
// Make sure node1 cannot connect in
|
||||
conn, err := cluster.layers[1].DialContext(t.Context(), server.addr, nil)
|
||||
@@ -201,9 +197,7 @@ func TestInmemCluster_ConnectCluster(t *testing.T) {
|
||||
stopCh := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
acceptConns := func(listener NetworkListener) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
wg.Go(func() {
|
||||
for {
|
||||
select {
|
||||
case <-stopCh:
|
||||
@@ -221,7 +215,7 @@ func TestInmemCluster_ConnectCluster(t *testing.T) {
|
||||
accepted.Add(1)
|
||||
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
// Start a listener on each node.
|
||||
|
||||
@@ -35,23 +35,19 @@ func InduceDeadlock(t *testing.T, vaultcore *Core, expected uint32) {
|
||||
}
|
||||
var mtx deadlock.Mutex
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
wg.Go(func() {
|
||||
vaultcore.expiration.coreStateLock.Lock()
|
||||
mtx.Lock()
|
||||
mtx.Unlock() //nolint:staticcheck
|
||||
vaultcore.expiration.coreStateLock.Unlock()
|
||||
}()
|
||||
})
|
||||
wg.Wait()
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
wg.Go(func() {
|
||||
mtx.Lock()
|
||||
vaultcore.expiration.coreStateLock.RLock()
|
||||
vaultcore.expiration.coreStateLock.RUnlock() //nolint:staticcheck
|
||||
mtx.Unlock()
|
||||
}()
|
||||
})
|
||||
wg.Wait()
|
||||
if deadlocks.Load() != expected {
|
||||
t.Fatalf("expected 1 deadlock, detected %d", deadlocks.Load())
|
||||
|
||||
@@ -54,10 +54,7 @@ func TestPostgreSQL_FencedWrites(t *testing.T) {
|
||||
var logs []string
|
||||
var wg sync.WaitGroup
|
||||
for range 10 {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
wg.Go(func() {
|
||||
var localLogs []string
|
||||
|
||||
// 5 iterations is roughly 2.5 seconds with the 5ms sleep.
|
||||
@@ -99,7 +96,7 @@ func TestPostgreSQL_FencedWrites(t *testing.T) {
|
||||
seenLogs[log] = struct{}{}
|
||||
}
|
||||
logLock.Unlock()
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
// Now sacrifice the leader's lock and ensure it doesn't write.
|
||||
|
||||
+2
-4
@@ -57,14 +57,12 @@ func TestGrabLockOrStop(t *testing.T) {
|
||||
// closerWg waits until the closer goroutine exits before we do
|
||||
// another iteration. This makes sure goroutines don't pile up.
|
||||
var closerWg sync.WaitGroup
|
||||
closerWg.Add(1)
|
||||
go func() {
|
||||
defer closerWg.Done()
|
||||
closerWg.Go(func() {
|
||||
// Close the stop channel half the time.
|
||||
if rand.Int()%2 == 0 {
|
||||
close(stop)
|
||||
}
|
||||
}()
|
||||
})
|
||||
|
||||
// Half the goroutines lock/unlock and the other half rlock/runlock.
|
||||
if g%2 == 0 {
|
||||
|
||||
@@ -829,7 +829,7 @@ func parsePingIDConfig(mConfig *mfa.Config, d *framework.FieldData) error {
|
||||
}
|
||||
|
||||
config := &mfa.PingIDConfig{}
|
||||
for _, line := range strings.Split(string(fileBytes), "\n") {
|
||||
for line := range strings.SplitSeq(string(fileBytes), "\n") {
|
||||
if strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -174,10 +174,7 @@ func (i *IdentityStore) LoadEntities(ctx context.Context, readOnly bool) error {
|
||||
|
||||
// Create 64 workers to distribute work to
|
||||
for range consts.ExpirationRestoreWorkerCount {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
wg.Go(func() {
|
||||
for {
|
||||
select {
|
||||
case key, ok := <-broker:
|
||||
@@ -200,13 +197,11 @@ func (i *IdentityStore) LoadEntities(ctx context.Context, readOnly bool) error {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
// Distribute the collected keys to the workers in a go routine
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
wg.Go(func() {
|
||||
for j, key := range existing {
|
||||
if j%500 == 0 {
|
||||
i.logger.Debug("entities loading", "progress", j)
|
||||
@@ -223,7 +218,7 @@ func (i *IdentityStore) LoadEntities(ctx context.Context, readOnly bool) error {
|
||||
|
||||
// Close the broker, causing worker routines to exit
|
||||
close(broker)
|
||||
}()
|
||||
})
|
||||
|
||||
// Restore each key by pulling from the result chan
|
||||
LOOP:
|
||||
|
||||
@@ -6,6 +6,7 @@ package vault
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/sha3"
|
||||
"crypto/sha512"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
@@ -48,7 +49,6 @@ import (
|
||||
"github.com/openbao/openbao/sdk/v2/logical"
|
||||
"github.com/openbao/openbao/vault/routing"
|
||||
"github.com/openbao/openbao/version"
|
||||
"golang.org/x/crypto/sha3"
|
||||
)
|
||||
|
||||
const maxBytes = 128 * 1024
|
||||
|
||||
Reference in New Issue
Block a user