From c69c5b5cd713f29e8aa5e3a7e29ff6f430d9491f Mon Sep 17 00:00:00 2001 From: buyfakett Date: Thu, 8 May 2025 10:08:36 +0800 Subject: [PATCH] feat: fix config --- config/default.yaml | 15 +++++++ go.mod | 16 +++++++- internal/config/config.go | 80 ++++++++++++++++++++++++-------------- internal/server/checker.go | 12 +++--- internal/server/server.go | 40 +++++++++---------- internal/server/torii.go | 10 ++--- main.go | 19 ++++----- 7 files changed, 122 insertions(+), 70 deletions(-) create mode 100644 config/default.yaml diff --git a/config/default.yaml b/config/default.yaml new file mode 100644 index 0000000..82e1de5 --- /dev/null +++ b/config/default.yaml @@ -0,0 +1,15 @@ +server: + port: "25555" + web_path: "/torii" + rule_path: "/www/server_torii/config/rules" + error_page: "/www/server_torii/config/error_page" + log_path: "/www/server_torii/log/" + node_name: "Server Torii" + connecting_host_headers: + - "Torii-Real-Host" + connecting_ip_headers: + - "Torii-Real-IP" + connecting_uri_headers: + - "Torii-Original-URI" + connecting_captcha_status_headers: + - "Torii-Captcha-Status" \ No newline at end of file diff --git a/go.mod b/go.mod index b7c019b..d79cba4 100644 --- a/go.mod +++ b/go.mod @@ -5,8 +5,22 @@ go 1.23.5 require ( github.com/cespare/xxhash/v2 v2.3.0 github.com/mssola/useragent v1.0.0 + github.com/spf13/pflag v1.0.6 + github.com/spf13/viper v1.20.1 go.uber.org/zap v1.27.0 gopkg.in/yaml.v3 v3.0.1 ) -require go.uber.org/multierr v1.10.0 // indirect +require ( + github.com/fsnotify/fsnotify v1.8.0 // indirect + github.com/go-viper/mapstructure/v2 v2.2.1 // indirect + github.com/pelletier/go-toml/v2 v2.2.3 // indirect + github.com/sagikazarmark/locafero v0.7.0 // indirect + github.com/sourcegraph/conc v0.3.0 // indirect + github.com/spf13/afero v1.12.0 // indirect + github.com/spf13/cast v1.7.1 // indirect + github.com/subosito/gotenv v1.6.0 // indirect + go.uber.org/multierr v1.10.0 // indirect + golang.org/x/sys v0.29.0 // indirect + golang.org/x/text v0.21.0 // indirect +) diff --git a/internal/config/config.go b/internal/config/config.go index b5a0f35..3f7b629 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -2,51 +2,73 @@ package config import ( "bufio" + "bytes" + "fmt" + "github.com/spf13/pflag" + "github.com/spf13/viper" "gopkg.in/yaml.v3" "net" "os" - "path/filepath" "regexp" "server_torii/internal/dataType" "server_torii/internal/utils" "strings" ) -type MainConfig struct { - Port string `yaml:"port"` - WebPath string `yaml:"web_path"` - RulePath string `yaml:"rule_path"` - ErrorPage string `yaml:"error_page"` - LogPath string `yaml:"log_path"` - NodeName string `yaml:"node_name"` - ConnectingHostHeaders []string `yaml:"connecting_host_headers"` - ConnectingIPHeaders []string `yaml:"connecting_ip_headers"` - ConnectingURIHeaders []string `yaml:"connecting_uri_headers"` - ConnectingCaptchaStatusHeaders []string `yaml:"connecting_captcha_status_headers"` +type ServerConfig struct { + Port string `mapstructure:"port"` + WebPath string `mapstructure:"web_path"` + RulePath string `mapstructure:"rule_path"` + ErrorPage string `mapstructure:"error_page"` + LogPath string `mapstructure:"log_path"` + NodeName string `mapstructure:"node_name"` + ConnectingHostHeaders []string `mapstructure:"connecting_host_headers"` + ConnectingIPHeaders []string `mapstructure:"connecting_ip_headers"` + ConnectingURIHeaders []string `mapstructure:"connecting_uri_headers"` + ConnectingCaptchaStatusHeaders []string `mapstructure:"connecting_captcha_status_headers"` } -// LoadMainConfig Read the configuration file and return the configuration object -func LoadMainConfig(basePath string) (*MainConfig, error) { - exePath, err := os.Executable() - if err != nil { - return nil, err - } - if basePath == "" { - basePath = filepath.Dir(exePath) - } - configPath := filepath.Join(basePath, "config", "torii.yml") +type AppConfig struct { + Server ServerConfig `mapstructure:"server"` +} - data, err := os.ReadFile(configPath) - if err != nil { - return nil, err +var Cfg AppConfig + +func InitConfig(defaultConfigContent []byte) { + // 1. 处理命令行参数 + var configFile string + pflag.StringVar(&configFile, "config", "", "Path to custom config file") + pflag.Parse() + + v := viper.New() + + // 2. 加载嵌入的默认配置文件 + if len(defaultConfigContent) > 0 { + v.SetConfigType("yml") + if err := v.ReadConfig(bytes.NewBuffer(defaultConfigContent)); err != nil { + fmt.Printf("加载默认配置失败: %v\n", err) + os.Exit(1) + } } - var cfg MainConfig - if err := yaml.Unmarshal(data, &cfg); err != nil { - return nil, err + // 3. 加载外部配置文件(如果存在) + if configFile != "" { + if _, err := os.Stat(configFile); err == nil { + v.SetConfigFile(configFile) + if err := v.MergeInConfig(); err != nil { + fmt.Printf("加载外部配置失败: %v (路径: %s)\n", err, configFile) + os.Exit(1) + } + } else { + fmt.Printf("警告: 外部配置文件不存在,使用默认配置 (路径: %s)\n", configFile) + } } - return &cfg, nil + // 4. 映射到结构体 + if err := v.Unmarshal(&Cfg); err != nil { + fmt.Println("解析配置失败:", err) + os.Exit(1) + } } // RuleSet stores all rules diff --git a/internal/server/checker.go b/internal/server/checker.go index 679c39b..9432448 100644 --- a/internal/server/checker.go +++ b/internal/server/checker.go @@ -15,7 +15,7 @@ import ( type CheckFunc func(dataType.UserRequest, *config.RuleSet, *action.Decision, *dataType.SharedMemory) -func CheckMain(w http.ResponseWriter, userRequestData dataType.UserRequest, ruleSet *config.RuleSet, cfg *config.MainConfig, sharedMem *dataType.SharedMemory) { +func CheckMain(w http.ResponseWriter, userRequestData dataType.UserRequest, ruleSet *config.RuleSet, sharedMem *dataType.SharedMemory) { decision := action.NewDecision() checkFuncs := make([]CheckFunc, 0) @@ -42,7 +42,7 @@ func CheckMain(w http.ResponseWriter, userRequestData dataType.UserRequest, rule return } } else if bytes.Compare(decision.HTTPCode, []byte("403")) == 0 { - tpl, err := template.ParseFiles(cfg.ErrorPage + "/403.html") + tpl, err := template.ParseFiles(config.Cfg.Server.ErrorPage + "/403.html") if err != nil { utils.LogError(userRequestData, fmt.Sprintf("Error parsing template: %v", err), "CheckMain") http.Error(w, "500 - Internal Server Error", http.StatusInternalServerError) @@ -54,7 +54,7 @@ func CheckMain(w http.ResponseWriter, userRequestData dataType.UserRequest, rule ConnectIP string Date string }{ - EdgeTag: cfg.NodeName, + EdgeTag: config.Cfg.Server.NodeName, ConnectIP: userRequestData.RemoteIP, Date: time.Now().Format("2006-01-02 15:04:05"), } @@ -67,7 +67,7 @@ func CheckMain(w http.ResponseWriter, userRequestData dataType.UserRequest, rule } } else if bytes.Compare(decision.HTTPCode, []byte("CAPTCHA")) == 0 { - tpl, err := template.ParseFiles(cfg.ErrorPage + "/CAPTCHA.html") + tpl, err := template.ParseFiles(config.Cfg.Server.ErrorPage + "/CAPTCHA.html") if err != nil { utils.LogError(userRequestData, fmt.Sprintf("Error parsing template: %v", err), "CheckMain") http.Error(w, "500 - Internal Server Error", http.StatusInternalServerError) @@ -83,7 +83,7 @@ func CheckMain(w http.ResponseWriter, userRequestData dataType.UserRequest, rule } } else if bytes.Compare(decision.HTTPCode, []byte("429")) == 0 { - tpl, err := template.ParseFiles(cfg.ErrorPage + "/429.html") + tpl, err := template.ParseFiles(config.Cfg.Server.ErrorPage + "/429.html") if err != nil { utils.LogError(userRequestData, fmt.Sprintf("Error parsing template: %v", err), "CheckMain") http.Error(w, "500 - Internal Server Error", http.StatusInternalServerError) @@ -94,7 +94,7 @@ func CheckMain(w http.ResponseWriter, userRequestData dataType.UserRequest, rule ConnectIP string Date string }{ - EdgeTag: cfg.NodeName, + EdgeTag: config.Cfg.Server.NodeName, ConnectIP: userRequestData.RemoteIP, Date: time.Now().Format("2006-01-02 15:04:05"), } diff --git a/internal/server/server.go b/internal/server/server.go index 7942343..efdcab5 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -10,33 +10,33 @@ import ( ) // StartServer starts the HTTP server -func StartServer(cfg *config.MainConfig, ruleSet *config.RuleSet, sharedMem *dataType.SharedMemory) error { +func StartServer(ruleSet *config.RuleSet, sharedMem *dataType.SharedMemory) error { http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - userRequestData := processRequestData(cfg, r) + userRequestData := processRequestData(r) - if strings.HasPrefix(userRequestData.Uri, cfg.WebPath) { - CheckTorii(w, r, userRequestData, ruleSet, cfg, sharedMem) + if strings.HasPrefix(userRequestData.Uri, config.Cfg.Server.WebPath) { + CheckTorii(w, r, userRequestData, ruleSet, sharedMem) } else { - CheckMain(w, userRequestData, ruleSet, cfg, sharedMem) + CheckMain(w, userRequestData, ruleSet, sharedMem) } }) - log.Printf("HTTP Server listening on :%s ...", cfg.Port) - return http.ListenAndServe(":"+cfg.Port, nil) + log.Printf("HTTP Server listening on :%s ...", config.Cfg.Server.Port) + return http.ListenAndServe(":"+config.Cfg.Server.Port, nil) } -func processRequestData(cfg *config.MainConfig, r *http.Request) dataType.UserRequest { +func processRequestData(r *http.Request) dataType.UserRequest { userRequest := dataType.UserRequest{ - RemoteIP: getClientIP(cfg, r), - Uri: getReqURI(cfg, r), - Captcha: getCaptchaStatus(cfg, r), + RemoteIP: getClientIP(r), + Uri: getReqURI(r), + Captcha: getCaptchaStatus(r), ToriiClearance: getHeader(r, "__torii_clearance"), ToriiSessionID: getHeader(r, "__torii_session_id"), UserAgent: r.UserAgent(), - Host: getReqHost(cfg, r), + Host: getReqHost(r), } return userRequest } @@ -49,9 +49,9 @@ func getHeader(r *http.Request, headerName string) string { return cookie.Value } -func getCaptchaStatus(cfg *config.MainConfig, r *http.Request) bool { +func getCaptchaStatus(r *http.Request) bool { captchaStatus := false - for _, headerName := range cfg.ConnectingCaptchaStatusHeaders { + for _, headerName := range config.Cfg.Server.ConnectingCaptchaStatusHeaders { if captchaVal := r.Header.Get(headerName); captchaVal != "" { if captchaVal == "on" { captchaStatus = true @@ -63,9 +63,9 @@ func getCaptchaStatus(cfg *config.MainConfig, r *http.Request) bool { } -func getReqURI(cfg *config.MainConfig, r *http.Request) string { +func getReqURI(r *http.Request) string { var clientURI string - for _, headerName := range cfg.ConnectingURIHeaders { + for _, headerName := range config.Cfg.Server.ConnectingURIHeaders { if uriVal := r.Header.Get(headerName); uriVal != "" { clientURI = uriVal break @@ -77,9 +77,9 @@ func getReqURI(cfg *config.MainConfig, r *http.Request) string { return clientURI } -func getClientIP(cfg *config.MainConfig, r *http.Request) string { +func getClientIP(r *http.Request) string { var clientIP string - for _, headerName := range cfg.ConnectingIPHeaders { + for _, headerName := range config.Cfg.Server.ConnectingIPHeaders { if ipVal := r.Header.Get(headerName); ipVal != "" { if strings.Contains(clientIP, ",") { parts := strings.Split(ipVal, ",") @@ -102,9 +102,9 @@ func getClientIP(cfg *config.MainConfig, r *http.Request) string { return clientIP } -func getReqHost(cfg *config.MainConfig, r *http.Request) string { +func getReqHost(r *http.Request) string { var clientHost = "" - for _, headerName := range cfg.ConnectingHostHeaders { + for _, headerName := range config.Cfg.Server.ConnectingHostHeaders { if hostVal := r.Header.Get(headerName); hostVal != "" { clientHost = hostVal break diff --git a/internal/server/torii.go b/internal/server/torii.go index 830b8fa..3ad1a99 100644 --- a/internal/server/torii.go +++ b/internal/server/torii.go @@ -12,13 +12,13 @@ import ( "time" ) -func CheckTorii(w http.ResponseWriter, r *http.Request, reqData dataType.UserRequest, ruleSet *config.RuleSet, cfg *config.MainConfig, sharedMem *dataType.SharedMemory) { +func CheckTorii(w http.ResponseWriter, r *http.Request, reqData dataType.UserRequest, ruleSet *config.RuleSet, sharedMem *dataType.SharedMemory) { decision := action.NewDecision() decision.SetCode(action.Continue, []byte("403")) - if reqData.Uri == cfg.WebPath+"/captcha" { + if reqData.Uri == config.Cfg.Server.WebPath+"/captcha" { check.CheckCaptcha(r, reqData, ruleSet, decision) - } else if reqData.Uri == cfg.WebPath+"/health_check" { + } else if reqData.Uri == config.Cfg.Server.WebPath+"/health_check" { decision.SetResponse(action.Done, []byte("200"), []byte("ok")) } if bytes.Compare(decision.HTTPCode, []byte("200")) == 0 { @@ -65,7 +65,7 @@ func CheckTorii(w http.ResponseWriter, r *http.Request, reqData dataType.UserReq } } } else { - tpl, err := template.ParseFiles(cfg.ErrorPage + "/403.html") + tpl, err := template.ParseFiles(config.Cfg.Server.ErrorPage + "/403.html") if err != nil { http.Error(w, "500 - Internal Server Error", http.StatusInternalServerError) return @@ -76,7 +76,7 @@ func CheckTorii(w http.ResponseWriter, r *http.Request, reqData dataType.UserReq ConnectIP string Date string }{ - EdgeTag: cfg.NodeName, + EdgeTag: config.Cfg.Server.NodeName, ConnectIP: reqData.RemoteIP, Date: time.Now().Format("2006-01-02 15:04:05"), } diff --git a/main.go b/main.go index ac72194..ed3c346 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + _ "embed" "flag" "log" "os" @@ -15,6 +16,9 @@ import ( "time" ) +//go:embed config/default.yaml +var defaultConfigContent []byte + func main() { var basePath string flag.StringVar(&basePath, "prefix", "", "Config file base path") @@ -25,21 +29,18 @@ func main() { } // Load MainConfig - cfg, err := config.LoadMainConfig(basePath) - if err != nil { - log.Fatalf("Load config failed: %v", err) - } + config.InitConfig(defaultConfigContent) // Load rules - ruleSet, err := config.LoadRules(cfg.RulePath) + ruleSet, err := config.LoadRules(config.Cfg.Server.RulePath) if err != nil { log.Fatalf("Load rules failed: %v", err) } - log.Printf("Ready to start server on port %s", cfg.Port) + log.Printf("Ready to start server on port %s", config.Cfg.Server.Port) //set log file - defaultLogPath := filepath.Join(cfg.LogPath + "server_torii.log") + defaultLogPath := filepath.Join(config.Cfg.Server.LogPath + "server_torii.log") logFile, err := os.OpenFile(defaultLogPath, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666) if err != nil { log.Fatalf("Failed to open log file: %v", err) @@ -52,7 +53,7 @@ func main() { }(logFile) log.SetOutput(logFile) - utils.InitLogx(cfg.LogPath) + utils.InitLogx(config.Cfg.Server.LogPath) //allocate shared memory sharedMem := &dataType.SharedMemory{ @@ -71,7 +72,7 @@ func main() { serverErr := make(chan error, 1) go func() { - serverErr <- server.StartServer(cfg, ruleSet, sharedMem) + serverErr <- server.StartServer(ruleSet, sharedMem) }() select {