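// Command bench creates TiDB Cloud serverless clusters concurrently to
// benchmark cluster-creation throughput and latency. It reuses the
// tidbcloud-cli configuration (API keys or a stored OAuth token) for
// authentication. Example invocation (flag values are illustrative):
//
//	go run . -total 50 -concurrency 10 -rps 5 -name-prefix bench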
package main

import (
	"context"
	"errors"
	"flag"
	"fmt"
	"log"
	"math"
	"os"
	"os/signal"
	"path/filepath"
	"sync"
	"sync/atomic"
	"time"

	"golang.org/x/time/rate"

	"github.com/tidbcloud/tidbcloud-cli/internal"
	"github.com/tidbcloud/tidbcloud-cli/internal/config"
	"github.com/tidbcloud/tidbcloud-cli/internal/config/store"
	"github.com/tidbcloud/tidbcloud-cli/internal/iostream"
	"github.com/tidbcloud/tidbcloud-cli/internal/prop"
	"github.com/tidbcloud/tidbcloud-cli/internal/service/aws/s3"
	"github.com/tidbcloud/tidbcloud-cli/internal/service/cloud"
	"github.com/tidbcloud/tidbcloud-cli/pkg/tidbcloud/v1beta1/serverless/cluster"

	"github.com/spf13/viper"
	"github.com/zalando/go-keyring"
)

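// Defaults for the benchmark flags; waitInterval and waitTimeout control
// how waitClusterReady polls for cluster state.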
const (
	defaultProjectID     = "1369847559694040868"
	defaultRegion        = "regions/aws-us-east-1"
	defaultNamePrefix    = "keep--1h"
	defaultSpendingLimit = 10
	defaultConcurrency   = 5
	defaultTotal         = 100
	defaultRPS           = 2.0
	waitInterval         = 2 * time.Second
	waitTimeout          = 10 * time.Minute
)

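// benchConfig collects every knob the benchmark exposes as a flag.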
type benchConfig struct {
	concurrency   int
	rps           float64
	total         int
	projectID     string
	region        string
	namePrefix    string
	spendingLimit int
	minRcu        int
	maxRcu        int
	encryption    bool
	disablePub    bool
	waitReady     bool
}

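// main wires up signal handling and CLI configuration, builds the API
// client, and hands off to runBench.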
func main() {
	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt)
	defer stop()

	initBenchConfig()
	config.SetActiveProfile(viper.GetString(prop.CurProfile))

	cfg := parseFlags()
	h := newHelper()

	client, err := h.Client()
	if err != nil {
		log.Fatalf("init client: %v", err)
	}

	runBench(ctx, client, cfg)
}

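// parseFlags binds the benchmark flags onto a benchConfig and validates
// the combination before any API call is made.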
func parseFlags() benchConfig {
	cfg := benchConfig{
		concurrency:   defaultConcurrency,
		rps:           defaultRPS,
		total:         defaultTotal,
		projectID:     defaultProjectID,
		region:        defaultRegion,
		namePrefix:    defaultNamePrefix,
		spendingLimit: defaultSpendingLimit,
	}

	flag.IntVar(&cfg.concurrency, "concurrency", cfg.concurrency, "number of concurrent workers")
	flag.Float64Var(&cfg.rps, "rps", cfg.rps, "requests per second")
	flag.IntVar(&cfg.total, "total", cfg.total, "total number of clusters to create")
	flag.StringVar(&cfg.projectID, "project-id", cfg.projectID, "project id")
	flag.StringVar(&cfg.region, "region", cfg.region, "region name")
	flag.StringVar(&cfg.namePrefix, "name-prefix", cfg.namePrefix, "prefix of the cluster name")
	flag.IntVar(&cfg.spendingLimit, "spending-limit", cfg.spendingLimit, "monthly spending limit in USD cents, Starter only")
	flag.IntVar(&cfg.minRcu, "min-rcu", 0, "minimum RCU, Essential only")
	flag.IntVar(&cfg.maxRcu, "max-rcu", 0, "maximum RCU, Essential only")
	flag.BoolVar(&cfg.encryption, "encryption", false, "enable enhanced encryption")
	flag.BoolVar(&cfg.disablePub, "disable-public-endpoint", false, "disable public endpoint")
	flag.BoolVar(&cfg.waitReady, "wait-ready", true, "wait for cluster to be ACTIVE")
	flag.Parse()

	if cfg.total <= 0 {
		log.Fatal("total must be positive")
	}

	if cfg.concurrency <= 0 {
		log.Fatal("concurrency must be positive")
	}

	if cfg.rps <= 0 {
		log.Fatal("rps must be positive")
	}

	if (cfg.minRcu > 0 || cfg.maxRcu > 0) && cfg.minRcu > cfg.maxRcu {
		log.Fatal("min-rcu cannot exceed max-rcu")
	}

	return cfg
}

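// initBenchConfig mirrors the tidbcloud-cli bootstrap: it ensures the CLI
// config directory and TOML config file exist, then loads them so stored
// profiles and credentials are available.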
func initBenchConfig() {
	home, err := os.UserHomeDir()
	if err != nil {
		log.Fatalf("get home: %v", err)
	}
	path := filepath.Join(home, config.HomePath)
	if err := os.MkdirAll(path, 0700); err != nil {
		log.Fatalf("init config dir: %v", err)
	}

	viper.AddConfigPath(path)
	viper.SetConfigType("toml")
	viper.SetConfigName("config")
	viper.SetConfigPermissions(0600)
	if err := viper.SafeWriteConfig(); err != nil {
		var existErr viper.ConfigFileAlreadyExistsError
		if !errors.As(err, &existErr) {
			log.Fatalf("write config: %v", err)
		}
	}
	if err := viper.ReadInConfig(); err != nil {
		log.Fatalf("read config: %v", err)
	}
}

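// newHelper builds an internal.Helper whose Client constructor prefers
// API keys and falls back to the stored OAuth access token.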
func newHelper() *internal.Helper {
	return &internal.Helper{
		Client: func() (cloud.TiDBCloudClient, error) {
			publicKey, privateKey := config.GetPublicKey(), config.GetPrivateKey()
			serverlessEndpoint := config.GetServerlessEndpoint()
			if serverlessEndpoint == "" {
				serverlessEndpoint = cloud.DefaultServerlessEndpoint
			}
			iamEndpoint := config.GetIAMEndpoint()
			if iamEndpoint == "" {
				iamEndpoint = cloud.DefaultIAMEndpoint
			}

			// API keys take precedence over a stored OAuth token.
			if publicKey != "" && privateKey != "" {
				return cloud.NewClientDelegateWithApiKey(publicKey, privateKey, serverlessEndpoint, iamEndpoint)
			}

			if err := config.ValidateToken(); err != nil {
				return nil, err
			}
			token, err := config.GetAccessToken()
			if err != nil {
				// Distinguish "no credentials stored" from other failures.
				if errors.Is(err, keyring.ErrNotFound) || errors.Is(err, store.ErrNotSupported) {
					return nil, fmt.Errorf("no stored access token, please authenticate first: %w", err)
				}
				return nil, err
			}
			return cloud.NewClientDelegateWithToken(token, serverlessEndpoint, iamEndpoint)
		},
		Uploader: func(client cloud.TiDBCloudClient) s3.Uploader {
			return s3.NewUploader(client)
		},
		QueryPageSize: internal.DefaultPageSize,
		IOStreams:     iostream.System(),
	}
}

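// runBench fans cfg.total creation jobs out to cfg.concurrency workers,
// each gated by a shared rate limiter, and reports success/failure counts.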
func runBench(ctx context.Context, client cloud.TiDBCloudClient, cfg benchConfig) {
	// Burst is rounded up from rps so fractional rates still permit one request.
	limiter := rate.NewLimiter(rate.Limit(cfg.rps), int(math.Ceil(cfg.rps)))
	jobs := make(chan int, cfg.total)

	var success int64
	var failed int64

	var wg sync.WaitGroup

	timestamp := time.Now().Unix()
	for i := 0; i < cfg.concurrency; i++ {
		wg.Add(1)
		go func(worker int) {
			defer wg.Done()
			for idx := range jobs {
				if err := limiter.Wait(ctx); err != nil {
					// The context was canceled: stop this worker instead of
					// spinning through the remaining queue.
					log.Printf("worker %d rate wait err: %v", worker, err)
					return
				}
				name := fmt.Sprintf("%s-%d-%d", cfg.namePrefix, timestamp, idx)
				start := time.Now()
				id, err := createOnce(ctx, client, cfg, name)
				if err != nil {
					atomic.AddInt64(&failed, 1)
					log.Printf("worker %d create %s failed: %v", worker, name, err)
					continue
				}

				if cfg.waitReady {
					if err := waitClusterReady(ctx, client, id); err != nil {
						atomic.AddInt64(&failed, 1)
						log.Printf("worker %d wait %s failed: %v", worker, id, err)
						continue
					}
				}

				atomic.AddInt64(&success, 1)
				log.Printf("worker %d create %s (id=%s) ok in %s", worker, name, id, time.Since(start))
			}
		}(i)
	}

	for i := 0; i < cfg.total; i++ {
		jobs <- i
	}
	close(jobs)

	wg.Wait()
	log.Printf("bench done: success=%d failed=%d", success, failed)
}

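// createOnce assembles the create-cluster payload from cfg and issues a
// single CreateCluster call, returning the new cluster ID.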
func createOnce(ctx context.Context, client cloud.TiDBCloudClient, cfg benchConfig, name string) (string, error) {
	payload := &cluster.TidbCloudOpenApiserverlessv1beta1Cluster{
		DisplayName: name,
		Region: cluster.Commonv1beta1Region{
			Name: &cfg.region,
		},
	}

	if cfg.projectID != "" {
		payload.Labels = &map[string]string{"tidb.cloud/project": cfg.projectID}
	}
	if cfg.spendingLimit > 0 {
		payload.SpendingLimit = &cluster.ClusterSpendingLimit{
			Monthly: toInt32Ptr(int32(cfg.spendingLimit)),
		}
	}
	if cfg.minRcu > 0 || cfg.maxRcu > 0 {
		payload.AutoScaling = &cluster.V1beta1ClusterAutoScaling{
			MinRcu: toInt64Ptr(int64(cfg.minRcu)),
			MaxRcu: toInt64Ptr(int64(cfg.maxRcu)),
		}
	}
	if cfg.encryption {
		payload.EncryptionConfig = &cluster.V1beta1ClusterEncryptionConfig{
			EnhancedEncryptionEnabled: &cfg.encryption,
		}
	}
	if cfg.disablePub {
		payload.Endpoints = &cluster.V1beta1ClusterEndpoints{
			Public: &cluster.EndpointsPublic{
				Disabled: &cfg.disablePub,
			},
		}
	}

	resp, err := client.CreateCluster(ctx, payload)
	if err != nil {
		return "", err
	}
	if resp.ClusterId == nil {
		return "", errors.New("empty cluster id")
	}
	return *resp.ClusterId, nil
}

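// waitClusterReady polls GetCluster every waitInterval until the cluster
// reaches ACTIVE, the waitTimeout elapses, or the context is canceled.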
func waitClusterReady(ctx context.Context, client cloud.TiDBCloudClient, clusterID string) error {
	ticker := time.NewTicker(waitInterval)
	defer ticker.Stop()
	timeout := time.After(waitTimeout)

	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-timeout:
			return fmt.Errorf("timeout waiting for cluster %s ready", clusterID)
		case <-ticker.C:
			c, err := client.GetCluster(ctx, clusterID, cluster.CLUSTERSERVICEGETCLUSTERVIEWPARAMETER_BASIC)
			if err != nil {
				return err
			}
			if c.State != nil && *c.State == cluster.COMMONV1BETA1CLUSTERSTATE_ACTIVE {
				return nil
			}
		}
	}
}

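// Pointer helpers for the generated API types, whose optional fields are
// expressed as pointers.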
func toInt32Ptr(v int32) *int32 {
	return &v
}

func toInt64Ptr(v int64) *int64 {
	return &v
}