Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 47 additions & 45 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ package client
import (
_context "context"
"encoding/json"
"errors"
"fmt"
"math"
_nethttp "net/http"
"time"

"github.com/sourcegraph/conc/pool"
"golang.org/x/sync/errgroup"

fgaSdk "github.com/openfga/go-sdk"
"github.com/openfga/go-sdk/credentials"
Expand Down Expand Up @@ -1773,18 +1773,14 @@ func (client *OpenFgaClient) WriteExecute(request SdkClientWriteRequestInterface
if request.GetBody() != nil {
for i := 0; i < len(request.GetBody().Writes); i += writeChunkSize {
end := int(math.Min(float64(i+writeChunkSize), float64(len(request.GetBody().Writes))))

writeChunks = append(writeChunks, (request.GetBody().Writes)[i:end])
}
}

writeGroup, ctx := errgroup.WithContext(request.GetContext())

writeGroup.SetLimit(int(maxParallelReqs))
writeResponses := make([]ClientWriteResponse, len(writeChunks))
for index, writeBody := range writeChunks {
index, writeBody := index, writeBody
writeGroup.Go(func() error {
writePool := pool.NewWithResults[*ClientWriteResponse]().WithContext(request.GetContext()).WithMaxGoroutines(int(maxParallelReqs))
for _, writeBody := range writeChunks {
writeBody := writeBody
writePool.Go(func(ctx _context.Context) (*ClientWriteResponse, error) {
singleResponse, err := client.WriteExecute(&SdkClientWriteRequest{
ctx: ctx,
Client: client,
Expand All @@ -1798,19 +1794,16 @@ func (client *OpenFgaClient) WriteExecute(request SdkClientWriteRequestInterface
Conflict: options.Conflict,
},
})

if _, ok := err.(fgaSdk.FgaApiAuthenticationError); ok {
return err
var authErr fgaSdk.FgaApiAuthenticationError
// If an error was returned then it will be an authentication error so we want to return
if errors.As(err, &authErr) {
return nil, err
}

writeResponses[index] = *singleResponse

return nil
return singleResponse, nil
})
}

err = writeGroup.Wait()
// If an error was returned then it will be an authentication error so we want to return
writeResponses, err := writePool.Wait()
if err != nil {
return &response, err
}
Expand All @@ -1825,12 +1818,10 @@ func (client *OpenFgaClient) WriteExecute(request SdkClientWriteRequestInterface
}
}

deleteGroup, ctx := errgroup.WithContext(request.GetContext())
deleteGroup.SetLimit(int(maxParallelReqs))
deleteResponses := make([]ClientWriteResponse, len(deleteChunks))
for index, deleteBody := range deleteChunks {
index, deleteBody := index, deleteBody
deleteGroup.Go(func() error {
deletePool := pool.NewWithResults[*ClientWriteResponse]().WithContext(request.GetContext()).WithMaxGoroutines(int(maxParallelReqs))
for _, deleteBody := range deleteChunks {
deleteBody := deleteBody
deletePool.Go(func(ctx _context.Context) (*ClientWriteResponse, error) {
singleResponse, err := client.WriteExecute(&SdkClientWriteRequest{
ctx: ctx,
Client: client,
Expand All @@ -1845,19 +1836,17 @@ func (client *OpenFgaClient) WriteExecute(request SdkClientWriteRequestInterface
},
})

if _, ok := err.(fgaSdk.FgaApiAuthenticationError); ok {
return err
var authErr fgaSdk.FgaApiAuthenticationError
if errors.As(err, &authErr) {
return nil, err
}

deleteResponses[index] = *singleResponse

return nil
return singleResponse, nil
})
}

err = deleteGroup.Wait()
deleteResponses, err := deletePool.Wait()
// If an error was returned then it will be an authentication error so we want to return
if err != nil {
// If an error was returned then it will be an authentication error so we want to return
return &response, err
}

Expand Down Expand Up @@ -2225,7 +2214,7 @@ func (request *SdkClientBatchCheckClientRequest) GetOptions() *ClientBatchCheckC
}

func (client *OpenFgaClient) ClientBatchCheckExecute(request SdkClientBatchCheckClientRequestInterface) (*ClientBatchCheckClientResponse, error) {
group, ctx := errgroup.WithContext(request.GetContext())
ctx := request.GetContext()
requestOptions := RequestOptions{}
maxParallelReqs := int(DEFAULT_MAX_METHOD_PARALLEL_REQS)
if request.GetOptions() != nil {
Expand All @@ -2235,7 +2224,6 @@ func (client *OpenFgaClient) ClientBatchCheckExecute(request SdkClientBatchCheck
}
}

group.SetLimit(maxParallelReqs)
var numOfChecks = len(*request.GetBody())
response := make(ClientBatchCheckClientResponse, numOfChecks)
authorizationModelId, err := client.getAuthorizationModelId(request.GetAuthorizationModelIdOverride())
Expand All @@ -2259,34 +2247,48 @@ func (client *OpenFgaClient) ClientBatchCheckExecute(request SdkClientBatchCheck
checkOptions.Consistency = request.GetOptions().Consistency
}

type batchCheckResult struct {
Index int
Response ClientBatchCheckClientSingleResponse
}

checkPool := pool.NewWithResults[*batchCheckResult]().WithContext(ctx).WithMaxGoroutines(maxParallelReqs)
for index, checkBody := range *request.GetBody() {
index, checkBody := index, checkBody
group.Go(func() error {
checkPool.Go(func(ctx _context.Context) (*batchCheckResult, error) {
singleResponse, err := client.CheckExecute(&SdkClientCheckRequest{
ctx: ctx,
Client: client,
body: &checkBody,
options: checkOptions,
})

if _, ok := err.(fgaSdk.FgaApiAuthenticationError); ok {
return err
}

response[index] = ClientBatchCheckClientSingleResponse{
Request: checkBody,
ClientCheckResponse: *singleResponse,
Error: err,
var authErr fgaSdk.FgaApiAuthenticationError
// If an error was returned then it will be an authentication error so we want to return
if errors.As(err, &authErr) {
return nil, err
}

return nil
return &batchCheckResult{
Index: index,
Response: ClientBatchCheckClientSingleResponse{
Request: checkBody,
ClientCheckResponse: *singleResponse,
Error: err,
},
}, nil
})
}

if err := group.Wait(); err != nil {
results, err := checkPool.Wait()
if err != nil {
return nil, err
}

for _, result := range results {
response[result.Index] = result.Response
}

return &response, nil
}

Expand Down
6 changes: 4 additions & 2 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package client_test
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
Expand Down Expand Up @@ -1764,7 +1765,8 @@ func TestOpenFgaClient(t *testing.T) {
t.Fatalf("Expect error with invalid auth but there is none")
}

if _, ok := err.(openfga.FgaApiAuthenticationError); !ok {
var authErr openfga.FgaApiAuthenticationError
if !errors.As(err, &authErr) {
t.Fatalf("Expected an api auth error")
}

Expand All @@ -1782,7 +1784,7 @@ func TestOpenFgaClient(t *testing.T) {
t.Fatalf("Expect error with invalid auth but there is none")
}

if _, ok := err.(openfga.FgaApiAuthenticationError); !ok {
if !errors.As(err, &authErr) {
t.Fatalf("Expected an api auth error")
}
})
Expand Down
Loading