diff --git a/internal/dataType/rateCounter.go b/internal/dataType/rateCounter.go index 597bd77..c275290 100644 --- a/internal/dataType/rateCounter.go +++ b/internal/dataType/rateCounter.go @@ -165,3 +165,16 @@ func (tc *Counter) GC() { bucket.mu.Unlock() } } + +func StartCounterGC(counter *Counter, interval time.Duration, stopCh <-chan struct{}) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + counter.GC() + case <-stopCh: + return + } + } +} diff --git a/internal/utils/RateParser.go b/internal/utils/RateParser.go index 041964e..1287ec2 100644 --- a/internal/utils/RateParser.go +++ b/internal/utils/RateParser.go @@ -39,3 +39,13 @@ func ParseRate(s string) (int64, int64, error) { } return int64(limit), int64(seconds), nil } + +func FindMaxRateTime(rateList map[int64]int64) int64 { + maxTimeWindow := int64(0) + for window := range rateList { + if window > maxTimeWindow { + maxTimeWindow = window + } + } + return maxTimeWindow +} diff --git a/main.go b/main.go index 9d7c3ed..0a93c57 100644 --- a/main.go +++ b/main.go @@ -5,10 +5,13 @@ import ( "log" "os" "os/signal" + "runtime" "server_torii/internal/config" "server_torii/internal/dataType" "server_torii/internal/server" + "server_torii/internal/utils" "syscall" + "time" ) func main() { @@ -49,10 +52,15 @@ func main() { //allocate shared memory sharedMem := &dataType.SharedMemory{ - HTTPFloodSpeedLimitCounter: dataType.NewCounter(64, 60), - HTTPFloodSameURILimitCounter: dataType.NewCounter(64, 60), + HTTPFloodSpeedLimitCounter: dataType.NewCounter(max(runtime.NumCPU()*8, 16), utils.FindMaxRateTime(ruleSet.HTTPFloodRule.HTTPFloodSpeedLimit)), + HTTPFloodSameURILimitCounter: dataType.NewCounter(max(runtime.NumCPU()*8, 16), utils.FindMaxRateTime(ruleSet.HTTPFloodRule.HTTPFloodSameURILimit)), } + //GC + gcStopCh := make(chan struct{}) + go dataType.StartCounterGC(sharedMem.HTTPFloodSpeedLimitCounter, time.Minute, gcStopCh) + go dataType.StartCounterGC(sharedMem.HTTPFloodSameURILimitCounter, time.Minute, gcStopCh) + // Start server stop := make(chan os.Signal, 1) signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM) @@ -65,6 +73,7 @@ func main() { select { case <-stop: log.Println("Stopping server...") + close(gcStopCh) case err := <-serverErr: if err != nil { log.Fatalf("Failed to start server: %v", err)