Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions flagcheck.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,13 @@ func CheckFlag(
company *Company,
user *User,
flag *Flag,
opts ...CheckFlagOption,
) (*CheckFlagResult, error) {
options := &checkFlagOptions{}
for _, opt := range opts {
opt(options)
}

resp := &CheckFlagResult{Reason: ReasonNoRulesMatched}

if flag == nil {
Expand Down Expand Up @@ -154,9 +160,11 @@ func CheckFlag(
}

checkRuleResp, err := ruleChecker.Check(ctx, &CheckScope{
Company: company,
Rule: rule,
User: user,
Company: company,
Rule: rule,
User: user,
Usage: options.usage,
EventUsage: options.eventUsage,
})
if err != nil {
resp.Err = err
Expand Down
18 changes: 14 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@ import (
)

type WasmInput struct {
Company Company `json:"company"`
User User `json:"user"`
Flag Flag `json:"flag"`
Company Company `json:"company"`
User User `json:"user"`
Flag Flag `json:"flag"`
Usage *int64 `json:"usage,omitempty"`
EventUsage map[string]int64 `json:"event_usage,omitempty"`
}

type WasmOutput struct {
Expand All @@ -26,7 +28,15 @@ func checkFlag(this js.Value, args []js.Value) interface{} {
var wasmInput WasmInput
json.Unmarshal([]byte(input), &wasmInput)

result, err := CheckFlag(context.Background(), &wasmInput.Company, &wasmInput.User, &wasmInput.Flag)
var opts []CheckFlagOption
if wasmInput.Usage != nil {
opts = append(opts, WithUsage(*wasmInput.Usage))
}
for eventSubtype, quantity := range wasmInput.EventUsage {
opts = append(opts, WithEventUsage(eventSubtype, quantity))
}

result, err := CheckFlag(context.Background(), &wasmInput.Company, &wasmInput.User, &wasmInput.Flag, opts...)

output := WasmOutput{
Result: result,
Expand Down
28 changes: 28 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package rulesengine

type CheckFlagOption func(*checkFlagOptions)

type checkFlagOptions struct {
usage *int64
eventUsage map[string]int64
}

// WithUsage increments the "usage" value for any numeric condition encountered while checking rules
// (trait, metric, credits). This is best for cases where the flag has only one type of rules with
// one type of trait, metric, or credit.
func WithUsage(quantity int64) CheckFlagOption {
return func(o *checkFlagOptions) {
o.usage = &quantity
}
}

// WithEventUsage specifies a specific event subtype, and for credit or metric conditions we check
// as if this additional usage had occurred.
func WithEventUsage(eventSubtype string, quantity int64) CheckFlagOption {
return func(o *checkFlagOptions) {
if o.eventUsage == nil {
o.eventUsage = make(map[string]int64)
}
o.eventUsage[eventSubtype] = quantity
}
}
116 changes: 80 additions & 36 deletions rulecheck.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ import (
)

type CheckScope struct {
Company *Company
Rule *Rule
User *User
Company *Company
Rule *Rule
User *User
Usage *int64
EventUsage map[string]int64
}

type CheckResult struct {
Expand Down Expand Up @@ -42,14 +44,14 @@ func (s *RuleCheckService) Check(ctx context.Context, scope *CheckScope) (res *C

var match bool
for _, condition := range scope.Rule.Conditions {
match, err = s.checkCondition(ctx, scope.Company, scope.User, condition)
match, err = s.checkCondition(ctx, scope, condition)
if err != nil || !match {
return
}
}

for _, group := range scope.Rule.ConditionGroups {
match, err = s.checkConditionGroup(ctx, scope.Company, scope.User, group)
match, err = s.checkConditionGroup(ctx, scope, group)
if err != nil || !match {
return
}
Expand All @@ -59,43 +61,43 @@ func (s *RuleCheckService) Check(ctx context.Context, scope *CheckScope) (res *C
return
}

func (s *RuleCheckService) checkCondition(ctx context.Context, company *Company, user *User, condition *Condition) (match bool, err error) {
func (s *RuleCheckService) checkCondition(ctx context.Context, scope *CheckScope, condition *Condition) (match bool, err error) {
if condition == nil {
return false, nil
}

switch condition.ConditionType {
case ConditionTypeCompany:
return s.checkCompanyCondition(ctx, company, condition)
return s.checkCompanyCondition(ctx, scope.Company, condition)
case ConditionTypeMetric:
return s.checkMetricCondition(ctx, company, condition)
return s.checkMetricCondition(ctx, scope, condition)
case ConditionTypeBasePlan:
return s.checkBasePlanCondition(ctx, company, condition)
return s.checkBasePlanCondition(ctx, scope.Company, condition)
case ConditionTypePlan:
return s.checkPlanCondition(ctx, company, condition)
return s.checkPlanCondition(ctx, scope.Company, condition)
case ConditionTypeTrait:
return s.checkTraitCondition(ctx, company, user, condition)
return s.checkTraitCondition(ctx, scope, condition)
case ConditionTypeUser:
return s.checkUserCondition(ctx, user, condition)
return s.checkUserCondition(ctx, scope.User, condition)
case ConditionTypeBillingProduct:
return s.checkBillingProductCondition(ctx, company, condition)
return s.checkBillingProductCondition(ctx, scope.Company, condition)
case ConditionTypeCrmProduct:
return s.checkCrmProductCondition(ctx, company, condition)
return s.checkCrmProductCondition(ctx, scope.Company, condition)
case ConditionTypeCredit:
return s.checkCreditBalanceCondition(ctx, company, condition)
return s.checkCreditBalanceCondition(ctx, scope, condition)
}

return
}

func (s *RuleCheckService) checkConditionGroup(ctx context.Context, company *Company, user *User, group *ConditionGroup) (bool, error) {
func (s *RuleCheckService) checkConditionGroup(ctx context.Context, scope *CheckScope, group *ConditionGroup) (bool, error) {
if group == nil {
return false, nil
}

// Condition groups are OR'd together, so we return true if any condition matches
for _, condition := range group.Conditions {
match, err := s.checkCondition(ctx, company, user, condition)
match, err := s.checkCondition(ctx, scope, condition)
if err != nil {
return false, err
}
Expand All @@ -121,25 +123,51 @@ func (s *RuleCheckService) checkCompanyCondition(ctx context.Context, company *C
return resourceMatch, nil
}

func (s *RuleCheckService) checkCreditBalanceCondition(ctx context.Context, company *Company, condition *Condition) (bool, error) {
if condition.ConditionType != ConditionTypeCredit || company == nil || condition.CreditID == nil {
func (s *RuleCheckService) checkCreditBalanceCondition(ctx context.Context, scope *CheckScope, condition *Condition) (bool, error) {
if condition.ConditionType != ConditionTypeCredit || scope.Company == nil || condition.CreditID == nil {
return false, nil
}

var consumptionCost = float64(1)
var consumptionRate = float64(1)
if condition.ConsumptionRate != nil {
consumptionCost = *condition.ConsumptionRate
consumptionRate = *condition.ConsumptionRate
}

var creditBalance float64
for creditID, balance := range company.CreditBalances {
for creditID, balance := range scope.Company.CreditBalances {
if creditID == *condition.CreditID {
creditBalance = balance
break
}
}

return creditBalance >= consumptionCost, nil
// WithUsage: Check if there are enough credits for generic usage
if scope.Usage != nil && *scope.Usage > 0 {
creditsNeeded := float64(*scope.Usage) * consumptionRate
return creditBalance >= creditsNeeded, nil
}

// WithEventUsage: Check if there are enough credits for event-specific usage
if condition.EventSubtype != nil && scope.EventUsage != nil {
if eventUsage, ok := scope.EventUsage[*condition.EventSubtype]; ok && eventUsage > 0 {
creditsNeeded := float64(eventUsage) * consumptionRate
return creditBalance >= creditsNeeded, nil
}
}

// Check against current metric usage if EventSubtype is specified
if condition.EventSubtype != nil {
usage := int64(0)
metric := scope.Company.Metrics.Find(*condition.EventSubtype, condition.MetricPeriod, condition.MetricPeriodMonthReset)
if metric != nil {
usage = metric.Value
}

creditsNeeded := float64(usage) * consumptionRate
return creditBalance >= creditsNeeded, nil
}

return creditBalance >= consumptionRate, nil
}

func (s *RuleCheckService) checkBillingProductCondition(ctx context.Context, company *Company, condition *Condition) (bool, error) {
Expand Down Expand Up @@ -208,26 +236,36 @@ func (s *RuleCheckService) checkBasePlanCondition(ctx context.Context, company *

func (s *RuleCheckService) checkMetricCondition(
ctx context.Context,
company *Company,
scope *CheckScope,
condition *Condition,
) (bool, error) {
if condition == nil || condition.ConditionType != ConditionTypeMetric || company == nil || condition.EventSubtype == nil {
if condition == nil || condition.ConditionType != ConditionTypeMetric || scope.Company == nil || condition.EventSubtype == nil {
return false, nil
}

leftVal := int64(0)
metric := company.Metrics.Find(*condition.EventSubtype, condition.MetricPeriod, condition.MetricPeriodMonthReset)
metric := scope.Company.Metrics.Find(*condition.EventSubtype, condition.MetricPeriod, condition.MetricPeriodMonthReset)
if metric != nil {
leftVal = metric.Value
}

// WithEventUsage: Add event-specific usage if this event matches
if scope.EventUsage != nil {
if eventUsage, ok := scope.EventUsage[*condition.EventSubtype]; ok && eventUsage > 0 {
leftVal += eventUsage
}
} else if scope.Usage != nil && *scope.Usage > 0 {
// WithUsage: Add generic usage to current metric value
leftVal += *scope.Usage
}

if condition.MetricValue == nil {
return false, fmt.Errorf("expected metric value for condition: %s, but received nil ", condition.ID)
}

rightVal := *condition.MetricValue
if condition.ComparisonTraitDefinition != nil {
comparisonTrait := s.findTrait(ctx, condition.ComparisonTraitDefinition, company.Traits)
comparisonTrait := s.findTrait(ctx, condition.ComparisonTraitDefinition, scope.Company.Traits)
if comparisonTrait == nil {
rightVal = 0
} else {
Expand All @@ -239,25 +277,25 @@ func (s *RuleCheckService) checkMetricCondition(

}

func (s *RuleCheckService) checkTraitCondition(ctx context.Context, company *Company, user *User, condition *Condition) (bool, error) {
func (s *RuleCheckService) checkTraitCondition(ctx context.Context, scope *CheckScope, condition *Condition) (bool, error) {
if condition == nil || condition.ConditionType != ConditionTypeTrait || condition.TraitDefinition == nil {
return false, nil
}

traitDef := condition.TraitDefinition
var trait *Trait
var comparisonTrait *Trait
if traitDef.EntityType == EntityTypeCompany && company != nil {
trait = s.findTrait(ctx, traitDef, company.Traits)
comparisonTrait = s.findTrait(ctx, condition.ComparisonTraitDefinition, company.Traits)
} else if traitDef.EntityType == EntityTypeUser && user != nil {
trait = s.findTrait(ctx, traitDef, user.Traits)
comparisonTrait = s.findTrait(ctx, condition.ComparisonTraitDefinition, user.Traits)
if traitDef.EntityType == EntityTypeCompany && scope.Company != nil {
trait = s.findTrait(ctx, traitDef, scope.Company.Traits)
comparisonTrait = s.findTrait(ctx, condition.ComparisonTraitDefinition, scope.Company.Traits)
} else if traitDef.EntityType == EntityTypeUser && scope.User != nil {
trait = s.findTrait(ctx, traitDef, scope.User.Traits)
comparisonTrait = s.findTrait(ctx, condition.ComparisonTraitDefinition, scope.User.Traits)
} else {
return false, nil
}

return s.compareTraits(ctx, condition, trait, comparisonTrait), nil
return s.compareTraitsWithUsage(ctx, scope, condition, trait, comparisonTrait), nil
}

func (s *RuleCheckService) checkUserCondition(ctx context.Context, user *User, condition *Condition) (bool, error) {
Expand All @@ -273,7 +311,7 @@ func (s *RuleCheckService) checkUserCondition(ctx context.Context, user *User, c
return resourceMatch, nil
}

func (s *RuleCheckService) compareTraits(ctx context.Context, condition *Condition, trait *Trait, comparisonTrait *Trait) bool {
func (s *RuleCheckService) compareTraitsWithUsage(ctx context.Context, scope *CheckScope, condition *Condition, trait *Trait, comparisonTrait *Trait) bool {
var leftVal string
rightVal := condition.TraitValue
if trait != nil {
Expand All @@ -288,6 +326,12 @@ func (s *RuleCheckService) compareTraits(ctx context.Context, condition *Conditi
comparableType = trait.TraitDefinition.ComparableType
}

if comparableType == typeconvert.ComparableTypeInt && scope.Usage != nil && *scope.Usage > 0 {
leftNumeric := typeconvert.StringToInt64(leftVal)
leftNumeric += *scope.Usage
leftVal = fmt.Sprintf("%d", leftNumeric)
}

return typeconvert.Compare(leftVal, rightVal, comparableType, condition.Operator)
}

Expand Down
Loading