From c7e9a69f8bc69a6fb00148fd1042c0cd936fe7ef Mon Sep 17 00:00:00 2001 From: Roi Feng <37480123+Rayzggz@users.noreply.github.com> Date: Thu, 13 Feb 2025 21:08:09 -0500 Subject: [PATCH] fix: Obtain User request URI --- config/torii.yml | 4 +- internal/config/config.go | 7 +-- internal/server/server.go | 90 +++++++++++++++++++++++++-------------- main.go | 2 +- 4 files changed, 67 insertions(+), 36 deletions(-) diff --git a/config/torii.yml b/config/torii.yml index ee70c1e..303d62c 100644 --- a/config/torii.yml +++ b/config/torii.yml @@ -1,4 +1,6 @@ port: "25555" rule_path: "/www/dev/server_torii/config/rules" connecting_ip_headers: - - "X-Real-IP" \ No newline at end of file + - "X-Real-IP" +connecting_uri_headers: + - "X-Original-URI" \ No newline at end of file diff --git a/internal/config/config.go b/internal/config/config.go index fb5fe4f..5489baa 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -12,9 +12,10 @@ import ( ) type MainConfig struct { - Port string `yaml:"port"` - RulePath string `yaml:"rule_path"` - ConnectingIPHeaders []string `yaml:"connecting_ip_headers"` + Port string `yaml:"port"` + RulePath string `yaml:"rule_path"` + ConnectingIPHeaders []string `yaml:"connecting_ip_headers"` + ConnectingURIHeaders []string `yaml:"connecting_uri_headers"` } // LoadMainConfig Read the configuration file and return the configuration object diff --git a/internal/server/server.go b/internal/server/server.go index 47dea66..be8a206 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -10,45 +10,30 @@ import ( "strings" ) -// StartServer starts the HTTP server -func StartServer(port string, ruleSet *config.RuleSet, ipHeaders []string) error { - http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - var clientIP string - for _, headerName := range ipHeaders { - if ipVal := r.Header.Get(headerName); ipVal != "" { - if strings.Contains(clientIP, ",") { - parts := strings.Split(ipVal, ",") - clientIP = strings.TrimSpace(parts[0]) - } - clientIP = ipVal - break - } - } +type userRequest struct { + remoteIP string + uri string +} - if clientIP == "" { - remoteAddr := r.RemoteAddr - ipStr, _, err := net.SplitHostPort(remoteAddr) - if err != nil { - //TODO: log error - clientIP = remoteAddr - } else { - clientIP = ipStr - } - } +// StartServer starts the HTTP server +func StartServer(cfg *config.MainConfig, ruleSet *config.RuleSet) error { + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + + userRequestData := processRequestData(cfg, r) decision := action.NewDecision() // run main check logic - checkIPAllow(clientIP, ruleSet.IPAllowTrie, decision) - checkIPBlock(clientIP, ruleSet.IPBlockTrie, decision) - checkURLAllow(r.RequestURI, ruleSet.URLAllowList, decision) - checkURLBlock(r.RequestURI, ruleSet.URLBlockList, decision) + checkIPAllow(userRequestData.remoteIP, ruleSet.IPAllowTrie, decision) + checkIPBlock(userRequestData.remoteIP, ruleSet.IPBlockTrie, decision) + checkURLAllow(userRequestData.uri, ruleSet.URLAllowList, decision) + checkURLBlock(userRequestData.uri, ruleSet.URLBlockList, decision) // if still undecided, allow if decision.Get() == action.Undecided { decision.Set(action.Allow) } - + log.Printf("clientIP: %s, decision: %s, Headers: %v", userRequestData.remoteIP, decision.Get(), r.Header) // return response if decision.Get() == action.Allow { w.WriteHeader(http.StatusOK) @@ -62,8 +47,51 @@ func StartServer(port string, ruleSet *config.RuleSet, ipHeaders []string) error } }) - log.Printf("HTTP Server listening on :%s ...", port) - return http.ListenAndServe(":"+port, nil) + log.Printf("HTTP Server listening on :%s ...", cfg.Port) + return http.ListenAndServe(":"+cfg.Port, nil) +} + +func processRequestData(cfg *config.MainConfig, r *http.Request) userRequest { + + var clientIP string + for _, headerName := range cfg.ConnectingIPHeaders { + if ipVal := r.Header.Get(headerName); ipVal != "" { + if strings.Contains(clientIP, ",") { + parts := strings.Split(ipVal, ",") + clientIP = strings.TrimSpace(parts[0]) + } + clientIP = ipVal + break + } + } + + if clientIP == "" { + remoteAddr := r.RemoteAddr + ipStr, _, err := net.SplitHostPort(remoteAddr) + if err != nil { + //TODO: log error + clientIP = remoteAddr + } else { + clientIP = ipStr + } + } + + var clientURI string + for _, headerName := range cfg.ConnectingURIHeaders { + if uriVal := r.Header.Get(headerName); uriVal != "" { + clientURI = uriVal + break + } + } + if clientURI == "" { + clientURI = r.RequestURI + } + + userRequest := userRequest{ + remoteIP: clientIP, + uri: clientURI, + } + return userRequest } func checkIPAllow(remoteIP string, trie *dataType.TrieNode, decision *action.Decision) { diff --git a/main.go b/main.go index 31a5f22..0e6a26a 100644 --- a/main.go +++ b/main.go @@ -35,7 +35,7 @@ func main() { serverErr := make(chan error, 1) go func() { - serverErr <- server.StartServer(cfg.Port, ruleSet, cfg.ConnectingIPHeaders) + serverErr <- server.StartServer(cfg, ruleSet) }() select {