diff --git a/backend/cmd/kubevoyage/main.go b/backend/cmd/kubevoyage/main.go index a4005ef..3ee8a75 100644 --- a/backend/cmd/kubevoyage/main.go +++ b/backend/cmd/kubevoyage/main.go @@ -1,13 +1,9 @@ package main import ( - "fmt" "github.com/B-Urb/KubeVoyage/internal/handlers" "github.com/B-Urb/KubeVoyage/internal/models" "github.com/rs/cors" - "gorm.io/driver/mysql" - "gorm.io/driver/postgres" - "gorm.io/driver/sqlite" "gorm.io/gorm" "log" "net/http" @@ -18,35 +14,19 @@ import ( var db *gorm.DB func main() { - // Read environment variables - dbType := os.Getenv("DB_TYPE") - dbHost := os.Getenv("DB_HOST") - dbPort := os.Getenv("DB_PORT") - dbUser := os.Getenv("DB_USER") - dbPassword := os.Getenv("DB_PASSWORD") - dbName := os.Getenv("DB_NAME") + app, err := NewApp() + if err != nil { + log.Fatalf("Failed to initialize app: %v", err) + } - var dsn string - var err error - var db *gorm.DB + app.Migrate() - switch dbType { - case "mysql": - dsn = fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", dbUser, dbPassword, dbHost, dbPort, dbName) - db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{}) - case "postgres": - dsn = fmt.Sprintf("host=%s port=%s user=%s dbname=%s password=%s sslmode=disable", dbHost, dbPort, dbUser, dbName, dbPassword) - db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{}) - case "sqlite": - dsn = dbName // For SQLite, dbName would be the path to the .db file - db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{}) - default: - log.Fatalf("Unsupported DB_TYPE: %s", dbType) - } + handler := setupServer(app) - if err != nil { - log.Fatalf("Failed to connect to database: %v", err) - } + log.Println("Starting server on :8080") + log.Fatal(http.ListenAndServe(":8080", handler)) +} +func setupServer(app *App) http.Handler { mux := http.NewServeMux() // Migrate the schema @@ -93,44 +73,11 @@ func main() { mux.HandleFunc("/api/request", func(w http.ResponseWriter, r *http.Request) { handlers.HandleRequestSite(w, r, db) }) - // Start the server on port 8081 - log.Println("Starting server on :8080") - - log.Fatal(http.ListenAndServe(":8080", handler)) - // ... setup your routes and start your server + handler := cors.Default().Handler(mux) + return handler } + func isAPIRoute(path string) bool { return len(path) >= 4 && path[0:4] == "/api" } -func generateTestData() { - // Insert test data for Users - users := []models.User{ - {Email: "user1@example.com", Password: "password1", Role: "admin"}, - {Email: "user2@example.com", Password: "password2", Role: "user"}, - {Email: "user3@example.com", Password: "password3", Role: "user"}, - } - for _, user := range users { - db.Create(&user) - } - - // Insert test data for Sites - sites := []models.Site{ - {URL: "https://site1.com"}, - {URL: "https://site2.com"}, - {URL: "https://site3.com"}, - } - for _, site := range sites { - db.Create(&site) - } - - // Insert test data for UserSite - userSites := []models.UserSite{ - {UserID: 1, SiteID: 1, State: "authorized"}, - {UserID: 2, SiteID: 2, State: "requested"}, - {UserID: 3, SiteID: 3, State: "authorized"}, - } - for _, userSite := range userSites { - db.Create(&userSite) - } -} diff --git a/backend/internal/app/app.go b/backend/internal/app/app.go new file mode 100644 index 0000000..b58cc3b --- /dev/null +++ b/backend/internal/app/app.go @@ -0,0 +1,50 @@ +package main + +import ( + "fmt" + "github.com/B-Urb/KubeVoyage/internal/models" + "gorm.io/gorm" + "log" + "os" +) + +type App struct { + DB *gorm.DB + JWTKey []byte + BaseURL string +} + +func NewApp() (*App, error) { + db, err := initializeDatabase() + if err != nil { + return nil, fmt.Errorf("failed to initialize database: %v", err) + } + + jwtKey, err := getEnvOrError("JWT_SECRET_KEY") + if err != nil { + log.Fatalf("Error reading JWT_SECRET_KEY: %v", err) + } + + baseURL, err := getEnvOrError("BASE_URL") + if err != nil { + log.Fatalf("Error reading BASE_URL: %v", err) + } + + return &App{ + DB: db, + JWTKey: []byte(jwtKey), + BaseURL: baseURL, + }, nil +} + +func (app *App) Migrate() { + app.DB.AutoMigrate(&models.User{}, &models.Site{}, &models.UserSite{}) +} + +func getEnvOrError(key string) (string, error) { + value := os.Getenv(key) + if value == "" { + return "", fmt.Errorf("environment variable %s not set", key) + } + return value, nil +} diff --git a/backend/internal/database/database.go b/backend/internal/database/database.go new file mode 100644 index 0000000..6c964de --- /dev/null +++ b/backend/internal/database/database.go @@ -0,0 +1,74 @@ +package database + +import ( + "fmt" + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "os" +) + +func InitializeDatabase() (*gorm.DB, error) { + // Read environment variables + dbType, err := getEnvOrError("DB_TYPE") + if err != nil { + return nil, err + } + + dbHost, err := getEnvOrError("DB_HOST") + if err != nil { + return nil, err + } + + dbPort, err := getEnvOrError("DB_PORT") + if err != nil { + return nil, err + } + + dbUser, err := getEnvOrError("DB_USER") + if err != nil { + return nil, err + } + + dbPassword, err := getEnvOrError("DB_PASSWORD") + if err != nil { + return nil, err + } + + dbName, err := getEnvOrError("DB_NAME") + if err != nil { + return nil, err + } + + var dsn string + var db *gorm.DB + + switch dbType { + case "mysql": + dsn = fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", dbUser, dbPassword, dbHost, dbPort, dbName) + db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{}) + case "postgres": + dsn = fmt.Sprintf("host=%s port=%s user=%s dbname=%s password=%s sslmode=disable", dbHost, dbPort, dbUser, dbName, dbPassword) + db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{}) + case "sqlite": + dsn = dbName // For SQLite, dbName would be the path to the .db file + db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{}) + default: + return nil, fmt.Errorf("Unsupported DB_TYPE: %s", dbType) + } + + if err != nil { + return nil, fmt.Errorf("Failed to connect to database: %v", err) + } + + return db, nil +} + +func getEnvOrError(key string) (string, error) { + value := os.Getenv(key) + if value == "" { + return "", fmt.Errorf("Environment variable %s not set", key) + } + return value, nil +} diff --git a/backend/internal/handlers/auth.go b/backend/internal/handlers/auth.go index 28d5cd1..40a299e 100644 --- a/backend/internal/handlers/auth.go +++ b/backend/internal/handlers/auth.go @@ -8,9 +8,9 @@ import ( "errors" "fmt" "github.com/B-Urb/KubeVoyage/internal/models" - "github.com/dgrijalva/jwt-go" "golang.org/x/crypto/scrypt" "gorm.io/gorm" + "log" "net/http" "net/url" "time" @@ -18,7 +18,7 @@ import ( var jwtKey = []byte("your_secret_key") -func HandleLogin(w http.ResponseWriter, r *http.Request, db *gorm.DB) { +func (app *App) HandleLogin(w http.ResponseWriter, r *http.Request) { var inputUser models.User var dbUser models.User @@ -82,7 +82,7 @@ func HandleLogin(w http.ResponseWriter, r *http.Request, db *gorm.DB) { w.Write([]byte("Login successful")) } -func HandleRegister(w http.ResponseWriter, r *http.Request, db *gorm.DB) { +func (app *App) HandleRegister(w http.ResponseWriter, r *http.Request) { var user models.User // Parse the request body @@ -125,30 +125,27 @@ func HandleRegister(w http.ResponseWriter, r *http.Request, db *gorm.DB) { sendJSONSuccess(w, "", http.StatusCreated) } -func HandleAuthenticate(w http.ResponseWriter, r *http.Request, db *gorm.DB) { +func (app *App) HandleAuthenticate(w http.ResponseWriter, r *http.Request) { // 1. Extract the user's email from the session or JWT token. - userEmail, err := getUserEmailFromToken(r) + userEmail, err := app.getUserEmailFromToken(r) if err != nil { - http.Error(w, err.Error(), http.StatusUnauthorized) + app.logError(w, "Failed to get user email from token", err, http.StatusUnauthorized) return } // 2. Extract the redirect parameter from the request to get the site URL. siteURL := r.Header.Get("X-Forwarded-Uri") if siteURL == "" { - http.Error(w, "Redirect URL missing", http.StatusBadRequest) - return + siteURL = r.URL.Query().Get("redirect") + if siteURL == "" { + app.logError(w, "Redirect URL missing from both header and URL parameter", nil, http.StatusBadRequest) + return + } } - //siteURL := r.URL.Query().Get("redirect") - //if siteURL == "" { - // http.Error(w, "Redirect URL missing", http.StatusBadRequest) - // return - //} - // 3. Query the database to check if the user has an "authorized" state for the given site. var userSite models.UserSite - err = db.Joins("JOIN users ON users.id = user_sites.user_id"). + err = app.DB.Joins("JOIN users ON users.id = user_sites.user_id"). Joins("JOIN sites ON sites.id = user_sites.site_id"). Where("users.email = ? AND sites.url = ? AND user_sites.state = ?", userEmail, siteURL, "authorized"). First(&userSite).Error @@ -159,13 +156,19 @@ func HandleAuthenticate(w http.ResponseWriter, r *http.Request, db *gorm.DB) { http.Redirect(w, r, "/request?redirect="+url.QueryEscape(siteURL), http.StatusSeeOther) return } - http.Error(w, "Database error", http.StatusInternalServerError) + app.logError(w, "Database error while checking user authorization", err, http.StatusInternalServerError) return } http.Redirect(w, r, siteURL, http.StatusSeeOther) +} - // If everything is fine, return a success message. - //w.Write([]byte("Access granted")) +func (app *App) logError(w http.ResponseWriter, message string, err error, statusCode int) { + logMessage := message + if err != nil { + logMessage = fmt.Sprintf("%s: %v", message, err) + } + log.Println(logMessage) + http.Error(w, message, statusCode) } func getUserEmailFromToken(r *http.Request) (string, error) { diff --git a/backend/internal/util/util.go b/backend/internal/util/util.go new file mode 100644 index 0000000..acd9a36 --- /dev/null +++ b/backend/internal/util/util.go @@ -0,0 +1,37 @@ +package main + +import ( + "github.com/B-Urb/KubeVoyage/internal/models" +) + +func generateTestData() { + // Insert test data for Users + users := []models.User{ + {Email: "user1@example.com", Password: "password1", Role: "admin"}, + {Email: "user2@example.com", Password: "password2", Role: "user"}, + {Email: "user3@example.com", Password: "password3", Role: "user"}, + } + for _, user := range users { + db.Create(&user) + } + + // Insert test data for Sites + sites := []models.Site{ + {URL: "https://site1.com"}, + {URL: "https://site2.com"}, + {URL: "https://site3.com"}, + } + for _, site := range sites { + db.Create(&site) + } + + // Insert test data for UserSite + userSites := []models.UserSite{ + {UserID: 1, SiteID: 1, State: "authorized"}, + {UserID: 2, SiteID: 2, State: "requested"}, + {UserID: 3, SiteID: 3, State: "authorized"}, + } + for _, userSite := range userSites { + db.Create(&userSite) + } +}