diff --git a/internal/check/Captcha.go b/internal/check/Captcha.go index 97178b4..daee3f9 100644 --- a/internal/check/Captcha.go +++ b/internal/check/Captcha.go @@ -25,7 +25,7 @@ type HCaptchaResponse struct { ErrorCodes []string `json:"error-codes"` } -func Captcha(reqData dataType.UserRequest, ruleSet *config.RuleSet, decision *action.Decision) { +func Captcha(reqData dataType.UserRequest, ruleSet *config.RuleSet, decision *action.Decision, sharedMem *dataType.SharedMemory) { if !reqData.Captcha { decision.Set(action.Continue) return diff --git a/internal/check/HTTPFlood.go b/internal/check/HTTPFlood.go new file mode 100644 index 0000000..d088ff2 --- /dev/null +++ b/internal/check/HTTPFlood.go @@ -0,0 +1,35 @@ +package check + +import ( + "log" + "server_torii/internal/action" + "server_torii/internal/config" + "server_torii/internal/dataType" +) + +func HTTPFlood(reqData dataType.UserRequest, ruleSet *config.RuleSet, decision *action.Decision, sharedMem *dataType.SharedMemory) { + ipKey := reqData.RemoteIP + sharedMem.HTTPFloodSpeedLimitCounter.Add(ipKey, 1) + + uriKey := reqData.RemoteIP + "|" + reqData.Uri + sharedMem.HTTPFloodSameURILimitCounter.Add(uriKey, 1) + + for window, limit := range ruleSet.HTTPFloodRule.HTTPFloodSpeedLimit { + if sharedMem.HTTPFloodSpeedLimitCounter.Query(ipKey, window) > limit { + log.Printf("HTTPFlood rate limit exceeded: IP %s, window %d, limit %d", ipKey, window, limit) + //decision.SetResponse(action.Done, []byte("403"), nil) + decision.Set(action.Continue) + return + } + } + + for window, limit := range ruleSet.HTTPFloodRule.HTTPFloodSameURILimit { + if sharedMem.HTTPFloodSameURILimitCounter.Query(uriKey, window) > limit { + log.Printf("HTTPFlood URI rate limit exceeded: IP %s, URI %s, window %d, limit %d", ipKey, reqData.Uri, window, limit) + //decision.SetResponse(action.Done, []byte("403"), nil) + decision.Set(action.Continue) + return + } + } + decision.Set(action.Continue) +} diff --git a/internal/check/IPAllow.go b/internal/check/IPAllow.go index bc90a2b..dfb7417 100644 --- a/internal/check/IPAllow.go +++ b/internal/check/IPAllow.go @@ -7,7 +7,7 @@ import ( "server_torii/internal/dataType" ) -func IPAllowList(reqData dataType.UserRequest, ruleSet *config.RuleSet, decision *action.Decision) { +func IPAllowList(reqData dataType.UserRequest, ruleSet *config.RuleSet, decision *action.Decision, sharedMem *dataType.SharedMemory) { remoteIP := reqData.RemoteIP trie := ruleSet.IPAllowTrie diff --git a/internal/check/IPBlock.go b/internal/check/IPBlock.go index 291bafc..490dddd 100644 --- a/internal/check/IPBlock.go +++ b/internal/check/IPBlock.go @@ -7,7 +7,7 @@ import ( "server_torii/internal/dataType" ) -func IPBlockList(reqData dataType.UserRequest, ruleSet *config.RuleSet, decision *action.Decision) { +func IPBlockList(reqData dataType.UserRequest, ruleSet *config.RuleSet, decision *action.Decision, sharedMem *dataType.SharedMemory) { remoteIP := reqData.RemoteIP trie := ruleSet.IPBlockTrie ip := net.ParseIP(remoteIP) diff --git a/internal/check/URLAllow.go b/internal/check/URLAllow.go index e80cf0c..fd89ecd 100644 --- a/internal/check/URLAllow.go +++ b/internal/check/URLAllow.go @@ -6,7 +6,7 @@ import ( "server_torii/internal/dataType" ) -func URLAllowList(reqData dataType.UserRequest, ruleSet *config.RuleSet, decision *action.Decision) { +func URLAllowList(reqData dataType.UserRequest, ruleSet *config.RuleSet, decision *action.Decision, sharedMem *dataType.SharedMemory) { url := reqData.Uri list := ruleSet.URLAllowList if list.Match(url) { diff --git a/internal/check/URLBlock.go b/internal/check/URLBlock.go index 55f6be3..e8143cb 100644 --- a/internal/check/URLBlock.go +++ b/internal/check/URLBlock.go @@ -6,7 +6,7 @@ import ( "server_torii/internal/dataType" ) -func URLBlockList(reqData dataType.UserRequest, ruleSet *config.RuleSet, decision *action.Decision) { +func URLBlockList(reqData dataType.UserRequest, ruleSet *config.RuleSet, decision *action.Decision, sharedMem *dataType.SharedMemory) { url := reqData.Uri list := ruleSet.URLBlockList if list.Match(url) { diff --git a/internal/check/VerifyBot.go b/internal/check/VerifyBot.go index 0d6dc7a..d19afa7 100644 --- a/internal/check/VerifyBot.go +++ b/internal/check/VerifyBot.go @@ -9,7 +9,7 @@ import ( "strings" ) -func VerifyBot(reqData dataType.UserRequest, ruleSet *config.RuleSet, decision *action.Decision) { +func VerifyBot(reqData dataType.UserRequest, ruleSet *config.RuleSet, decision *action.Decision, sharedMem *dataType.SharedMemory) { ua := strings.ToLower(reqData.UserAgent) var exptractRDNS []string diff --git a/internal/config/config.go b/internal/config/config.go index d00fb48..b5a0f35 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -235,8 +235,8 @@ func loadHTTPFloodRule(file string, rule *dataType.HTTPFloodRule) error { return err } - rule.HTTPFloodSpeedLimit = make(map[int]int) - rule.HTTPFloodSameURILimit = make(map[int]int) + rule.HTTPFloodSpeedLimit = make(map[int64]int64) + rule.HTTPFloodSameURILimit = make(map[int64]int64) for _, s := range ymlRule.HTTPFloodSpeedLimit { limit, seconds, err := utils.ParseRate(s) diff --git a/internal/dataType/type.go b/internal/dataType/type.go index f522f20..5c77443 100644 --- a/internal/dataType/type.go +++ b/internal/dataType/type.go @@ -27,9 +27,11 @@ type VerifyBotRule struct { } type HTTPFloodRule struct { - HTTPFloodSpeedLimit map[int]int - HTTPFloodSameURILimit map[int]int + HTTPFloodSpeedLimit map[int64]int64 + HTTPFloodSameURILimit map[int64]int64 } type SharedMemory struct { + HTTPFloodSpeedLimitCounter *Counter + HTTPFloodSameURILimitCounter *Counter } diff --git a/internal/server/checker.go b/internal/server/checker.go index 9946e82..ffb835e 100644 --- a/internal/server/checker.go +++ b/internal/server/checker.go @@ -12,9 +12,9 @@ import ( "time" ) -type CheckFunc func(dataType.UserRequest, *config.RuleSet, *action.Decision) +type CheckFunc func(dataType.UserRequest, *config.RuleSet, *action.Decision, *dataType.SharedMemory) -func CheckMain(w http.ResponseWriter, userRequestData dataType.UserRequest, ruleSet *config.RuleSet, cfg *config.MainConfig) { +func CheckMain(w http.ResponseWriter, userRequestData dataType.UserRequest, ruleSet *config.RuleSet, cfg *config.MainConfig, sharedMem *dataType.SharedMemory) { decision := action.NewDecision() checkFuncs := make([]CheckFunc, 0) @@ -23,10 +23,11 @@ func CheckMain(w http.ResponseWriter, userRequestData dataType.UserRequest, rule checkFuncs = append(checkFuncs, check.URLAllowList) checkFuncs = append(checkFuncs, check.URLBlockList) checkFuncs = append(checkFuncs, check.VerifyBot) + checkFuncs = append(checkFuncs, check.HTTPFlood) checkFuncs = append(checkFuncs, check.Captcha) for _, checkFunc := range checkFuncs { - checkFunc(userRequestData, ruleSet, decision) + checkFunc(userRequestData, ruleSet, decision, sharedMem) if decision.State == action.Done { break } diff --git a/internal/server/server.go b/internal/server/server.go index cb438d6..7942343 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -10,15 +10,15 @@ import ( ) // StartServer starts the HTTP server -func StartServer(cfg *config.MainConfig, ruleSet *config.RuleSet) error { +func StartServer(cfg *config.MainConfig, ruleSet *config.RuleSet, sharedMem *dataType.SharedMemory) error { http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { userRequestData := processRequestData(cfg, r) if strings.HasPrefix(userRequestData.Uri, cfg.WebPath) { - CheckTorii(w, r, userRequestData, ruleSet, cfg) + CheckTorii(w, r, userRequestData, ruleSet, cfg, sharedMem) } else { - CheckMain(w, userRequestData, ruleSet, cfg) + CheckMain(w, userRequestData, ruleSet, cfg, sharedMem) } }) diff --git a/internal/server/torii.go b/internal/server/torii.go index e314732..e50c91f 100644 --- a/internal/server/torii.go +++ b/internal/server/torii.go @@ -12,7 +12,7 @@ import ( "time" ) -func CheckTorii(w http.ResponseWriter, r *http.Request, reqData dataType.UserRequest, ruleSet *config.RuleSet, cfg *config.MainConfig) { +func CheckTorii(w http.ResponseWriter, r *http.Request, reqData dataType.UserRequest, ruleSet *config.RuleSet, cfg *config.MainConfig, sharedMem *dataType.SharedMemory) { decision := action.NewDecision() decision.SetCode(action.Continue, []byte("403")) diff --git a/internal/utils/RateParser.go b/internal/utils/RateParser.go index 8f4fb36..041964e 100644 --- a/internal/utils/RateParser.go +++ b/internal/utils/RateParser.go @@ -6,7 +6,7 @@ import ( "strings" ) -func ParseRate(s string) (int, int, error) { +func ParseRate(s string) (int64, int64, error) { parts := strings.Split(s, "/") if len(parts) != 2 { return 0, 0, fmt.Errorf("unexpected rate format: %s", s) @@ -37,5 +37,5 @@ func ParseRate(s string) (int, int, error) { default: return 0, 0, fmt.Errorf("unexpected time unit: %s", string(unit)) } - return limit, seconds, nil + return int64(limit), int64(seconds), nil } diff --git a/main.go b/main.go index 6d619b2..9d7c3ed 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "os" "os/signal" "server_torii/internal/config" + "server_torii/internal/dataType" "server_torii/internal/server" "syscall" ) @@ -47,6 +48,10 @@ func main() { log.SetOutput(logFile) //allocate shared memory + sharedMem := &dataType.SharedMemory{ + HTTPFloodSpeedLimitCounter: dataType.NewCounter(64, 60), + HTTPFloodSameURILimitCounter: dataType.NewCounter(64, 60), + } // Start server stop := make(chan os.Signal, 1) @@ -54,7 +59,7 @@ func main() { serverErr := make(chan error, 1) go func() { - serverErr <- server.StartServer(cfg, ruleSet) + serverErr <- server.StartServer(cfg, ruleSet, sharedMem) }() select {