diff --git a/config/default.yaml b/config/default.yaml deleted file mode 100644 index 82e1de5..0000000 --- a/config/default.yaml +++ /dev/null @@ -1,15 +0,0 @@ -server: - port: "25555" - web_path: "/torii" - rule_path: "/www/server_torii/config/rules" - error_page: "/www/server_torii/config/error_page" - log_path: "/www/server_torii/log/" - node_name: "Server Torii" - connecting_host_headers: - - "Torii-Real-Host" - connecting_ip_headers: - - "Torii-Real-IP" - connecting_uri_headers: - - "Torii-Original-URI" - connecting_captcha_status_headers: - - "Torii-Captcha-Status" \ No newline at end of file diff --git a/go.mod b/go.mod index d79cba4..b7c019b 100644 --- a/go.mod +++ b/go.mod @@ -5,22 +5,8 @@ go 1.23.5 require ( github.com/cespare/xxhash/v2 v2.3.0 github.com/mssola/useragent v1.0.0 - github.com/spf13/pflag v1.0.6 - github.com/spf13/viper v1.20.1 go.uber.org/zap v1.27.0 gopkg.in/yaml.v3 v3.0.1 ) -require ( - github.com/fsnotify/fsnotify v1.8.0 // indirect - github.com/go-viper/mapstructure/v2 v2.2.1 // indirect - github.com/pelletier/go-toml/v2 v2.2.3 // indirect - github.com/sagikazarmark/locafero v0.7.0 // indirect - github.com/sourcegraph/conc v0.3.0 // indirect - github.com/spf13/afero v1.12.0 // indirect - github.com/spf13/cast v1.7.1 // indirect - github.com/subosito/gotenv v1.6.0 // indirect - go.uber.org/multierr v1.10.0 // indirect - golang.org/x/sys v0.29.0 // indirect - golang.org/x/text v0.21.0 // indirect -) +require go.uber.org/multierr v1.10.0 // indirect diff --git a/internal/config/config.go b/internal/config/config.go index 3f7b629..b5a0f35 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -2,73 +2,51 @@ package config import ( "bufio" - "bytes" - "fmt" - "github.com/spf13/pflag" - "github.com/spf13/viper" "gopkg.in/yaml.v3" "net" "os" + "path/filepath" "regexp" "server_torii/internal/dataType" "server_torii/internal/utils" "strings" ) -type ServerConfig struct { - Port string `mapstructure:"port"` - WebPath string `mapstructure:"web_path"` - RulePath string `mapstructure:"rule_path"` - ErrorPage string `mapstructure:"error_page"` - LogPath string `mapstructure:"log_path"` - NodeName string `mapstructure:"node_name"` - ConnectingHostHeaders []string `mapstructure:"connecting_host_headers"` - ConnectingIPHeaders []string `mapstructure:"connecting_ip_headers"` - ConnectingURIHeaders []string `mapstructure:"connecting_uri_headers"` - ConnectingCaptchaStatusHeaders []string `mapstructure:"connecting_captcha_status_headers"` +type MainConfig struct { + Port string `yaml:"port"` + WebPath string `yaml:"web_path"` + RulePath string `yaml:"rule_path"` + ErrorPage string `yaml:"error_page"` + LogPath string `yaml:"log_path"` + NodeName string `yaml:"node_name"` + ConnectingHostHeaders []string `yaml:"connecting_host_headers"` + ConnectingIPHeaders []string `yaml:"connecting_ip_headers"` + ConnectingURIHeaders []string `yaml:"connecting_uri_headers"` + ConnectingCaptchaStatusHeaders []string `yaml:"connecting_captcha_status_headers"` } -type AppConfig struct { - Server ServerConfig `mapstructure:"server"` -} +// LoadMainConfig Read the configuration file and return the configuration object +func LoadMainConfig(basePath string) (*MainConfig, error) { + exePath, err := os.Executable() + if err != nil { + return nil, err + } + if basePath == "" { + basePath = filepath.Dir(exePath) + } + configPath := filepath.Join(basePath, "config", "torii.yml") -var Cfg AppConfig - -func InitConfig(defaultConfigContent []byte) { - // 1. 处理命令行参数 - var configFile string - pflag.StringVar(&configFile, "config", "", "Path to custom config file") - pflag.Parse() - - v := viper.New() - - // 2. 加载嵌入的默认配置文件 - if len(defaultConfigContent) > 0 { - v.SetConfigType("yml") - if err := v.ReadConfig(bytes.NewBuffer(defaultConfigContent)); err != nil { - fmt.Printf("加载默认配置失败: %v\n", err) - os.Exit(1) - } + data, err := os.ReadFile(configPath) + if err != nil { + return nil, err } - // 3. 加载外部配置文件(如果存在) - if configFile != "" { - if _, err := os.Stat(configFile); err == nil { - v.SetConfigFile(configFile) - if err := v.MergeInConfig(); err != nil { - fmt.Printf("加载外部配置失败: %v (路径: %s)\n", err, configFile) - os.Exit(1) - } - } else { - fmt.Printf("警告: 外部配置文件不存在,使用默认配置 (路径: %s)\n", configFile) - } + var cfg MainConfig + if err := yaml.Unmarshal(data, &cfg); err != nil { + return nil, err } - // 4. 映射到结构体 - if err := v.Unmarshal(&Cfg); err != nil { - fmt.Println("解析配置失败:", err) - os.Exit(1) - } + return &cfg, nil } // RuleSet stores all rules diff --git a/internal/server/checker.go b/internal/server/checker.go index 9432448..679c39b 100644 --- a/internal/server/checker.go +++ b/internal/server/checker.go @@ -15,7 +15,7 @@ import ( type CheckFunc func(dataType.UserRequest, *config.RuleSet, *action.Decision, *dataType.SharedMemory) -func CheckMain(w http.ResponseWriter, userRequestData dataType.UserRequest, ruleSet *config.RuleSet, sharedMem *dataType.SharedMemory) { +func CheckMain(w http.ResponseWriter, userRequestData dataType.UserRequest, ruleSet *config.RuleSet, cfg *config.MainConfig, sharedMem *dataType.SharedMemory) { decision := action.NewDecision() checkFuncs := make([]CheckFunc, 0) @@ -42,7 +42,7 @@ func CheckMain(w http.ResponseWriter, userRequestData dataType.UserRequest, rule return } } else if bytes.Compare(decision.HTTPCode, []byte("403")) == 0 { - tpl, err := template.ParseFiles(config.Cfg.Server.ErrorPage + "/403.html") + tpl, err := template.ParseFiles(cfg.ErrorPage + "/403.html") if err != nil { utils.LogError(userRequestData, fmt.Sprintf("Error parsing template: %v", err), "CheckMain") http.Error(w, "500 - Internal Server Error", http.StatusInternalServerError) @@ -54,7 +54,7 @@ func CheckMain(w http.ResponseWriter, userRequestData dataType.UserRequest, rule ConnectIP string Date string }{ - EdgeTag: config.Cfg.Server.NodeName, + EdgeTag: cfg.NodeName, ConnectIP: userRequestData.RemoteIP, Date: time.Now().Format("2006-01-02 15:04:05"), } @@ -67,7 +67,7 @@ func CheckMain(w http.ResponseWriter, userRequestData dataType.UserRequest, rule } } else if bytes.Compare(decision.HTTPCode, []byte("CAPTCHA")) == 0 { - tpl, err := template.ParseFiles(config.Cfg.Server.ErrorPage + "/CAPTCHA.html") + tpl, err := template.ParseFiles(cfg.ErrorPage + "/CAPTCHA.html") if err != nil { utils.LogError(userRequestData, fmt.Sprintf("Error parsing template: %v", err), "CheckMain") http.Error(w, "500 - Internal Server Error", http.StatusInternalServerError) @@ -83,7 +83,7 @@ func CheckMain(w http.ResponseWriter, userRequestData dataType.UserRequest, rule } } else if bytes.Compare(decision.HTTPCode, []byte("429")) == 0 { - tpl, err := template.ParseFiles(config.Cfg.Server.ErrorPage + "/429.html") + tpl, err := template.ParseFiles(cfg.ErrorPage + "/429.html") if err != nil { utils.LogError(userRequestData, fmt.Sprintf("Error parsing template: %v", err), "CheckMain") http.Error(w, "500 - Internal Server Error", http.StatusInternalServerError) @@ -94,7 +94,7 @@ func CheckMain(w http.ResponseWriter, userRequestData dataType.UserRequest, rule ConnectIP string Date string }{ - EdgeTag: config.Cfg.Server.NodeName, + EdgeTag: cfg.NodeName, ConnectIP: userRequestData.RemoteIP, Date: time.Now().Format("2006-01-02 15:04:05"), } diff --git a/internal/server/server.go b/internal/server/server.go index efdcab5..7942343 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -10,33 +10,33 @@ import ( ) // StartServer starts the HTTP server -func StartServer(ruleSet *config.RuleSet, sharedMem *dataType.SharedMemory) error { +func StartServer(cfg *config.MainConfig, ruleSet *config.RuleSet, sharedMem *dataType.SharedMemory) error { http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - userRequestData := processRequestData(r) + userRequestData := processRequestData(cfg, r) - if strings.HasPrefix(userRequestData.Uri, config.Cfg.Server.WebPath) { - CheckTorii(w, r, userRequestData, ruleSet, sharedMem) + if strings.HasPrefix(userRequestData.Uri, cfg.WebPath) { + CheckTorii(w, r, userRequestData, ruleSet, cfg, sharedMem) } else { - CheckMain(w, userRequestData, ruleSet, sharedMem) + CheckMain(w, userRequestData, ruleSet, cfg, sharedMem) } }) - log.Printf("HTTP Server listening on :%s ...", config.Cfg.Server.Port) - return http.ListenAndServe(":"+config.Cfg.Server.Port, nil) + log.Printf("HTTP Server listening on :%s ...", cfg.Port) + return http.ListenAndServe(":"+cfg.Port, nil) } -func processRequestData(r *http.Request) dataType.UserRequest { +func processRequestData(cfg *config.MainConfig, r *http.Request) dataType.UserRequest { userRequest := dataType.UserRequest{ - RemoteIP: getClientIP(r), - Uri: getReqURI(r), - Captcha: getCaptchaStatus(r), + RemoteIP: getClientIP(cfg, r), + Uri: getReqURI(cfg, r), + Captcha: getCaptchaStatus(cfg, r), ToriiClearance: getHeader(r, "__torii_clearance"), ToriiSessionID: getHeader(r, "__torii_session_id"), UserAgent: r.UserAgent(), - Host: getReqHost(r), + Host: getReqHost(cfg, r), } return userRequest } @@ -49,9 +49,9 @@ func getHeader(r *http.Request, headerName string) string { return cookie.Value } -func getCaptchaStatus(r *http.Request) bool { +func getCaptchaStatus(cfg *config.MainConfig, r *http.Request) bool { captchaStatus := false - for _, headerName := range config.Cfg.Server.ConnectingCaptchaStatusHeaders { + for _, headerName := range cfg.ConnectingCaptchaStatusHeaders { if captchaVal := r.Header.Get(headerName); captchaVal != "" { if captchaVal == "on" { captchaStatus = true @@ -63,9 +63,9 @@ func getCaptchaStatus(r *http.Request) bool { } -func getReqURI(r *http.Request) string { +func getReqURI(cfg *config.MainConfig, r *http.Request) string { var clientURI string - for _, headerName := range config.Cfg.Server.ConnectingURIHeaders { + for _, headerName := range cfg.ConnectingURIHeaders { if uriVal := r.Header.Get(headerName); uriVal != "" { clientURI = uriVal break @@ -77,9 +77,9 @@ func getReqURI(r *http.Request) string { return clientURI } -func getClientIP(r *http.Request) string { +func getClientIP(cfg *config.MainConfig, r *http.Request) string { var clientIP string - for _, headerName := range config.Cfg.Server.ConnectingIPHeaders { + for _, headerName := range cfg.ConnectingIPHeaders { if ipVal := r.Header.Get(headerName); ipVal != "" { if strings.Contains(clientIP, ",") { parts := strings.Split(ipVal, ",") @@ -102,9 +102,9 @@ func getClientIP(r *http.Request) string { return clientIP } -func getReqHost(r *http.Request) string { +func getReqHost(cfg *config.MainConfig, r *http.Request) string { var clientHost = "" - for _, headerName := range config.Cfg.Server.ConnectingHostHeaders { + for _, headerName := range cfg.ConnectingHostHeaders { if hostVal := r.Header.Get(headerName); hostVal != "" { clientHost = hostVal break diff --git a/internal/server/torii.go b/internal/server/torii.go index 3ad1a99..830b8fa 100644 --- a/internal/server/torii.go +++ b/internal/server/torii.go @@ -12,13 +12,13 @@ import ( "time" ) -func CheckTorii(w http.ResponseWriter, r *http.Request, reqData dataType.UserRequest, ruleSet *config.RuleSet, sharedMem *dataType.SharedMemory) { +func CheckTorii(w http.ResponseWriter, r *http.Request, reqData dataType.UserRequest, ruleSet *config.RuleSet, cfg *config.MainConfig, sharedMem *dataType.SharedMemory) { decision := action.NewDecision() decision.SetCode(action.Continue, []byte("403")) - if reqData.Uri == config.Cfg.Server.WebPath+"/captcha" { + if reqData.Uri == cfg.WebPath+"/captcha" { check.CheckCaptcha(r, reqData, ruleSet, decision) - } else if reqData.Uri == config.Cfg.Server.WebPath+"/health_check" { + } else if reqData.Uri == cfg.WebPath+"/health_check" { decision.SetResponse(action.Done, []byte("200"), []byte("ok")) } if bytes.Compare(decision.HTTPCode, []byte("200")) == 0 { @@ -65,7 +65,7 @@ func CheckTorii(w http.ResponseWriter, r *http.Request, reqData dataType.UserReq } } } else { - tpl, err := template.ParseFiles(config.Cfg.Server.ErrorPage + "/403.html") + tpl, err := template.ParseFiles(cfg.ErrorPage + "/403.html") if err != nil { http.Error(w, "500 - Internal Server Error", http.StatusInternalServerError) return @@ -76,7 +76,7 @@ func CheckTorii(w http.ResponseWriter, r *http.Request, reqData dataType.UserReq ConnectIP string Date string }{ - EdgeTag: config.Cfg.Server.NodeName, + EdgeTag: cfg.NodeName, ConnectIP: reqData.RemoteIP, Date: time.Now().Format("2006-01-02 15:04:05"), } diff --git a/main.go b/main.go index ed3c346..ac72194 100644 --- a/main.go +++ b/main.go @@ -1,7 +1,6 @@ package main import ( - _ "embed" "flag" "log" "os" @@ -16,9 +15,6 @@ import ( "time" ) -//go:embed config/default.yaml -var defaultConfigContent []byte - func main() { var basePath string flag.StringVar(&basePath, "prefix", "", "Config file base path") @@ -29,18 +25,21 @@ func main() { } // Load MainConfig - config.InitConfig(defaultConfigContent) + cfg, err := config.LoadMainConfig(basePath) + if err != nil { + log.Fatalf("Load config failed: %v", err) + } // Load rules - ruleSet, err := config.LoadRules(config.Cfg.Server.RulePath) + ruleSet, err := config.LoadRules(cfg.RulePath) if err != nil { log.Fatalf("Load rules failed: %v", err) } - log.Printf("Ready to start server on port %s", config.Cfg.Server.Port) + log.Printf("Ready to start server on port %s", cfg.Port) //set log file - defaultLogPath := filepath.Join(config.Cfg.Server.LogPath + "server_torii.log") + defaultLogPath := filepath.Join(cfg.LogPath + "server_torii.log") logFile, err := os.OpenFile(defaultLogPath, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666) if err != nil { log.Fatalf("Failed to open log file: %v", err) @@ -53,7 +52,7 @@ func main() { }(logFile) log.SetOutput(logFile) - utils.InitLogx(config.Cfg.Server.LogPath) + utils.InitLogx(cfg.LogPath) //allocate shared memory sharedMem := &dataType.SharedMemory{ @@ -72,7 +71,7 @@ func main() { serverErr := make(chan error, 1) go func() { - serverErr <- server.StartServer(ruleSet, sharedMem) + serverErr <- server.StartServer(cfg, ruleSet, sharedMem) }() select {