package config import ( "fmt" "log" "lost-and-found/internal/models" "os" "strings" "time" "gorm.io/driver/mysql" "gorm.io/gorm" "gorm.io/gorm/logger" ) var db *gorm.DB // DatabaseConfig holds database connection configuration type DatabaseConfig struct { Host string Port string User string Password string DBName string Charset string ParseTime string Loc string } // GetDatabaseConfig returns database configuration from environment func GetDatabaseConfig() DatabaseConfig { return DatabaseConfig{ Host: getEnv("DB_HOST", "localhost"), Port: getEnv("DB_PORT", "3306"), User: getEnv("DB_USER", "root"), Password: getEnv("DB_PASSWORD", ""), DBName: getEnv("DB_NAME", "lost_and_found"), Charset: getEnv("DB_CHARSET", "utf8mb4"), ParseTime: getEnv("DB_PARSE_TIME", "True"), Loc: getEnv("DB_LOC", "Local"), } } // InitDB initializes database connection func InitDB() error { config := GetDatabaseConfig() // Step 1: Connect to MySQL without specifying database (to create if not exists) if err := ensureDatabaseExists(config); err != nil { return err } // Step 2: Connect to the specific database dsn := fmt.Sprintf( "%s:%s@tcp(%s:%s)/%s?charset=%s&parseTime=%s&loc=%s&multiStatements=true", config.User, config.Password, config.Host, config.Port, config.DBName, config.Charset, config.ParseTime, config.Loc, ) // Configure GORM logger gormLogger := logger.Default if IsDevelopment() { gormLogger = logger.Default.LogMode(logger.Info) } else { gormLogger = logger.Default.LogMode(logger.Error) } // Open database connection var err error db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{ Logger: gormLogger, NowFunc: func() time.Time { return time.Now().Local() }, }) if err != nil { return fmt.Errorf("failed to connect to database: %w", err) } // Get underlying SQL database sqlDB, err := db.DB() if err != nil { return fmt.Errorf("failed to get database instance: %w", err) } // Set connection pool settings sqlDB.SetMaxIdleConns(10) sqlDB.SetMaxOpenConns(100) sqlDB.SetConnMaxLifetime(time.Hour) // Test connection if err := sqlDB.Ping(); err != nil { return fmt.Errorf("failed to ping database: %w", err) } log.Println("✅ Database connected successfully") return nil } // ensureDatabaseExists checks if database exists, creates it if not func ensureDatabaseExists(config DatabaseConfig) error { // Connect to MySQL without specifying a database dsn := fmt.Sprintf( "%s:%s@tcp(%s:%s)/?charset=%s&parseTime=%s&loc=%s", config.User, config.Password, config.Host, config.Port, config.Charset, config.ParseTime, config.Loc, ) tempDB, err := gorm.Open(mysql.Open(dsn), &gorm.Config{ Logger: logger.Default.LogMode(logger.Silent), }) if err != nil { return fmt.Errorf("failed to connect to MySQL server: %w", err) } log.Printf("🔍 Checking if database '%s' exists...", config.DBName) // Check if database exists var dbExists int64 if err := tempDB.Raw( "SELECT COUNT(*) FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?", config.DBName, ).Scan(&dbExists).Error; err != nil { return fmt.Errorf("failed to check database existence: %w", err) } if dbExists == 0 { log.Printf("📝 Creating database '%s'...", config.DBName) createSQL := fmt.Sprintf( "CREATE DATABASE IF NOT EXISTS %s CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci", config.DBName, ) if err := tempDB.Exec(createSQL).Error; err != nil { return fmt.Errorf("failed to create database: %w", err) } log.Printf("✅ Database '%s' created successfully", config.DBName) } else { log.Printf("✅ Database '%s' already exists", config.DBName) } // Close temporary connection sqlDB, _ := tempDB.DB() sqlDB.Close() return nil } // GetDB returns the database instance func GetDB() *gorm.DB { return db } // RunMigrations runs database migrations from SQL files func RunMigrations(db *gorm.DB) error { log.Println("📊 Starting database migrations...") // Check if tables already exist if db.Migrator().HasTable(&models.Role{}) { log.Println("✅ Database tables already exist, skipping migration") return nil } log.Println("📋 Tables not found, running migration scripts...") // Step 1: Run schema.sql if err := runSQLFile(db, "database/schema.sql"); err != nil { return fmt.Errorf("❌ Failed to run schema.sql: %w", err) } log.Println("✅ Schema created successfully") // Step 2: Run seed.sql if err := runSQLFile(db, "database/seed.sql"); err != nil { return fmt.Errorf("❌ Failed to run seed.sql: %w", err) } log.Println("✅ Seed data inserted successfully") // Step 3: Run enhancement.sql (optional - for triggers, procedures, etc) if err := runSQLFile(db, "database/enhancement.sql"); err != nil { log.Printf("⚠️ Warning: Failed to run enhancement.sql: %v", err) log.Println("💡 Enhancement features (triggers, procedures) may not be available") } else { log.Println("✅ Enhancement features loaded successfully") } log.Println("🎉 Database migration completed!") log.Println("🔧 Default admin: admin@lostandfound.com / password123") return nil } // runSQLFile executes SQL from file func runSQLFile(db *gorm.DB, filepath string) error { // Check if file exists if _, err := os.Stat(filepath); os.IsNotExist(err) { return fmt.Errorf("file not found: %s", filepath) } log.Printf("📄 Reading SQL file: %s", filepath) // Read file content, err := os.ReadFile(filepath) if err != nil { return fmt.Errorf("failed to read file: %w", err) } // Split SQL content by delimiter sqlContent := string(content) // Remove comments and empty lines sqlContent = removeComments(sqlContent) // Split by DELIMITER if exists (for procedures/triggers) if strings.Contains(sqlContent, "DELIMITER") { return executeSQLWithDelimiter(db, sqlContent) } // Execute SQL normally if err := db.Exec(sqlContent).Error; err != nil { return fmt.Errorf("failed to execute SQL: %w", err) } log.Printf("✅ SQL file executed: %s", filepath) return nil } // executeSQLWithDelimiter handles SQL with custom delimiters (for procedures/triggers) func executeSQLWithDelimiter(db *gorm.DB, content string) error { // Split by DELIMITER changes parts := strings.Split(content, "DELIMITER") for i, part := range parts { part = strings.TrimSpace(part) if part == "" || part == "$$" || part == ";" { continue } // Remove the delimiter declaration line lines := strings.Split(part, "\n") if len(lines) > 0 && (strings.HasPrefix(lines[0], "$$") || strings.HasPrefix(lines[0], ";")) { lines = lines[1:] } part = strings.Join(lines, "\n") // Split by custom delimiter ($$) if i%2 == 1 { // Odd parts use $$ delimiter statements := strings.Split(part, "$$") for _, stmt := range statements { stmt = strings.TrimSpace(stmt) if stmt == "" || stmt == ";" { continue } if err := db.Exec(stmt).Error; err != nil { log.Printf("⚠️ Warning executing statement: %v", err) // Don't fail on enhancement errors (triggers, procedures) // These might fail if they already exist or MySQL version issues } } } else { // Even parts use ; delimiter statements := strings.Split(part, ";") for _, stmt := range statements { stmt = strings.TrimSpace(stmt) if stmt == "" { continue } if err := db.Exec(stmt).Error; err != nil { return fmt.Errorf("failed to execute statement: %w", err) } } } } return nil } // removeComments removes SQL comments from content func removeComments(sql string) string { lines := strings.Split(sql, "\n") var cleaned []string for _, line := range lines { line = strings.TrimSpace(line) // Skip empty lines if line == "" { continue } // Skip single-line comments if strings.HasPrefix(line, "--") || strings.HasPrefix(line, "#") { continue } // Keep the line cleaned = append(cleaned, line) } return strings.Join(cleaned, "\n") } // CloseDB closes database connection func CloseDB() error { if db != nil { sqlDB, err := db.DB() if err != nil { return err } return sqlDB.Close() } return nil } // Helper function to get environment variable with default value func getEnv(key, defaultValue string) string { value := os.Getenv(key) if value == "" { return defaultValue } return value }