2025-11-23 22:49:46 +07:00

333 lines
8.2 KiB
Go

package config
import (
	"fmt"
	"log"
	"net/url"
	"os"
	"strings"
	"time"

	"lost-and-found/internal/models"

	"gorm.io/driver/mysql"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
)
// db is the package-level GORM handle: InitDB assigns it, GetDB returns it
// and CloseDB closes its underlying connection pool.
var db *gorm.DB
// DatabaseConfig describes how to reach the MySQL server and which database
// to use. Values are normally produced by GetDatabaseConfig from DB_*
// environment variables.
type DatabaseConfig struct {
	Host, Port     string // server address, used as tcp(Host:Port) in the DSN
	User, Password string // login credentials
	DBName         string // database (schema) to connect to
	// DSN query parameters, emitted as charset=, parseTime= and loc=.
	Charset, ParseTime, Loc string
}
// GetDatabaseConfig builds a DatabaseConfig from the DB_* environment
// variables. Any variable that is unset or empty falls back to its default:
// localhost:3306, user root with an empty password, database lost_and_found,
// charset utf8mb4, parseTime True and loc Local.
func GetDatabaseConfig() DatabaseConfig {
	var cfg DatabaseConfig

	// Server address.
	cfg.Host = getEnv("DB_HOST", "localhost")
	cfg.Port = getEnv("DB_PORT", "3306")

	// Credentials and target database.
	cfg.User = getEnv("DB_USER", "root")
	cfg.Password = getEnv("DB_PASSWORD", "")
	cfg.DBName = getEnv("DB_NAME", "lost_and_found")

	// Driver (DSN) parameters.
	cfg.Charset = getEnv("DB_CHARSET", "utf8mb4")
	cfg.ParseTime = getEnv("DB_PARSE_TIME", "True")
	cfg.Loc = getEnv("DB_LOC", "Local")

	return cfg
}
// InitDB initializes the package-level database handle returned by GetDB.
//
// It first makes sure the configured database exists (creating it if
// needed), then opens a GORM connection to it, configures the connection
// pool and pings the server. The package-level handle is replaced only after
// every step has succeeded. On failure, any pool opened along the way is
// closed instead of being leaked.
func InitDB() error {
	config := GetDatabaseConfig()

	// Step 1: Connect to MySQL without specifying database (to create if not exists)
	if err := ensureDatabaseExists(config); err != nil {
		return err
	}

	// Step 2: Connect to the specific database. multiStatements=true lets a
	// whole SQL script be sent in a single Exec (see runSQLFile).
	dsn := mysqlDSN(config, config.DBName, "&multiStatements=true")

	// Verbose SQL logging in development, errors only otherwise.
	logLevel := logger.Error
	if IsDevelopment() {
		logLevel = logger.Info
	}

	conn, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
		Logger: logger.Default.LogMode(logLevel),
		NowFunc: func() time.Time {
			return time.Now().Local()
		},
	})
	if err != nil {
		closeGormPool(conn)
		return fmt.Errorf("failed to connect to database: %w", err)
	}

	// Get underlying SQL database
	sqlDB, err := conn.DB()
	if err != nil {
		return fmt.Errorf("failed to get database instance: %w", err)
	}

	// Set connection pool settings
	sqlDB.SetMaxIdleConns(10)
	sqlDB.SetMaxOpenConns(100)
	sqlDB.SetConnMaxLifetime(time.Hour)

	// Test connection
	if err := sqlDB.Ping(); err != nil {
		_ = sqlDB.Close() // do not leak the pool of a failed initialization
		return fmt.Errorf("failed to ping database: %w", err)
	}

	db = conn
	log.Println("✅ Database connected successfully")
	return nil
}

// ensureDatabaseExists connects to the MySQL server without selecting a
// database. If information_schema reports no schema named config.DBName, it
// creates one with CHARACTER SET utf8mb4 and COLLATE utf8mb4_unicode_ci.
// The temporary connection pool is closed on every return path.
func ensureDatabaseExists(config DatabaseConfig) error {
	// Connect to MySQL without specifying a database
	dsn := mysqlDSN(config, "", "")

	tempDB, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
		Logger: logger.Default.LogMode(logger.Silent),
	})
	// closeGormPool is nil-safe, so the deferred close is registered before
	// the error check and covers every return below.
	defer closeGormPool(tempDB)
	if err != nil {
		return fmt.Errorf("failed to connect to MySQL server: %w", err)
	}

	log.Printf("🔍 Checking if database '%s' exists...", config.DBName)

	// Check if database exists
	var dbExists int64
	if err := tempDB.Raw(
		"SELECT COUNT(*) FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?",
		config.DBName,
	).Scan(&dbExists).Error; err != nil {
		return fmt.Errorf("failed to check database existence: %w", err)
	}

	if dbExists != 0 {
		log.Printf("✅ Database '%s' already exists", config.DBName)
		return nil
	}

	log.Printf("📝 Creating database '%s'...", config.DBName)
	// Identifiers cannot be bound as query parameters, so the name is quoted
	// instead. Quoting also allows names such as "lost-and-found" that are
	// not valid as bare identifiers.
	createSQL := fmt.Sprintf(
		"CREATE DATABASE IF NOT EXISTS %s CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci",
		quoteMySQLIdentifier(config.DBName),
	)
	if err := tempDB.Exec(createSQL).Error; err != nil {
		return fmt.Errorf("failed to create database: %w", err)
	}
	log.Printf("✅ Database '%s' created successfully", config.DBName)
	return nil
}

// mysqlDSN builds a MySQL DSN for config. dbName selects the database; an
// empty dbName gives a server-level connection. extraParams is appended
// verbatim and must be empty or begin with "&".
//
// Query parameter values go through escapeDSNParam. The driver requires
// parameter values to be URL query-escaped, and an unescaped "/" (e.g.
// loc=Asia/Jakarta) makes the DSN unparseable. User and password are inserted
// verbatim because the driver does not URL-decode them.
func mysqlDSN(config DatabaseConfig, dbName, extraParams string) string {
	return fmt.Sprintf(
		"%s:%s@tcp(%s:%s)/%s?charset=%s&parseTime=%s&loc=%s%s",
		config.User,
		config.Password,
		config.Host,
		config.Port,
		dbName,
		escapeDSNParam(config.Charset),
		escapeDSNParam(config.ParseTime),
		escapeDSNParam(config.Loc),
		extraParams,
	)
}

// escapeDSNParam returns v in URL query-escaped form. It percent-decodes v
// first, without treating "+" as a space, so operator-supplied values that
// are already escaped (e.g. "Asia%2FJakarta") are not double-encoded, while
// raw values such as "Etc/GMT+7" are encoded correctly.
func escapeDSNParam(v string) string {
	if decoded, err := url.PathUnescape(v); err == nil {
		v = decoded
	}
	return url.QueryEscape(v)
}

// quoteMySQLIdentifier wraps name in backticks, doubling any embedded
// backticks, so it can be used safely as a MySQL identifier.
func quoteMySQLIdentifier(name string) string {
	return "`" + strings.ReplaceAll(name, "`", "``") + "`"
}

// closeGormPool closes the *sql.DB pool behind g, if there is one. It
// accepts nil and ignores close errors because it is only used on cleanup
// paths.
func closeGormPool(g *gorm.DB) {
	if g == nil {
		return
	}
	if sqlDB, err := g.DB(); err == nil {
		_ = sqlDB.Close()
	}
}
// GetDB returns the shared GORM handle prepared by InitDB. It is nil until a
// handle has been assigned.
func GetDB() *gorm.DB { return db }
// RunMigrations initializes a fresh database from the SQL scripts under
// database/.
//
// If the table for models.Role already exists, the database is considered
// migrated and nothing runs. Otherwise schema.sql and seed.sql are executed
// in order, and a failure in either aborts with an error. enhancement.sql
// (triggers, procedures, ...) is then attempted; its failure is only logged
// as a warning.
func RunMigrations(db *gorm.DB) error {
	log.Println("📊 Starting database migrations...")

	if alreadyMigrated := db.Migrator().HasTable(&models.Role{}); alreadyMigrated {
		log.Println("✅ Database tables already exist, skipping migration")
		return nil
	}
	log.Println("📋 Tables not found, running migration scripts...")

	// Mandatory scripts, in execution order, with their success messages.
	required := []struct {
		file    string
		doneMsg string
	}{
		{file: "schema.sql", doneMsg: "✅ Schema created successfully"},
		{file: "seed.sql", doneMsg: "✅ Seed data inserted successfully"},
	}
	for _, step := range required {
		if err := runSQLFile(db, "database/"+step.file); err != nil {
			return fmt.Errorf("❌ Failed to run %s: %w", step.file, err)
		}
		log.Println(step.doneMsg)
	}

	// Optional script: a failure degrades features but does not abort.
	if err := runSQLFile(db, "database/enhancement.sql"); err == nil {
		log.Println("✅ Enhancement features loaded successfully")
	} else {
		log.Printf("⚠️ Warning: Failed to run enhancement.sql: %v", err)
		log.Println("💡 Enhancement features (triggers, procedures) may not be available")
	}

	log.Println("🎉 Database migration completed!")
	log.Println("🔧 Default admin: admin@lostandfound.com / password123")
	return nil
}
// runSQLFile reads the SQL script at filepath and executes it against db.
//
// Whole-line comments and blank lines are stripped first (removeComments).
// Scripts containing "DELIMITER" directives are handed to
// executeSQLWithDelimiter. Anything else is sent as a single Exec, which
// needs a connection that allows multiple statements per query (InitDB opens
// its connection with multiStatements=true). A script that is empty after
// stripping is treated as a successful no-op; sending it would make the
// server reject an empty query.
//
// A missing file yields a "file not found: <path>" error.
func runSQLFile(db *gorm.DB, filepath string) error {
	log.Printf("📄 Reading SQL file: %s", filepath)

	// A single ReadFile replaces the earlier Stat+ReadFile pair, which was
	// redundant and racy.
	content, err := os.ReadFile(filepath)
	if os.IsNotExist(err) {
		return fmt.Errorf("file not found: %s", filepath)
	}
	if err != nil {
		return fmt.Errorf("failed to read file %s: %w", filepath, err)
	}

	sqlContent := removeComments(string(content))

	switch {
	case sqlContent == "":
		log.Printf("⚠️ SQL file has no statements, skipping: %s", filepath)
		return nil
	case strings.Contains(sqlContent, "DELIMITER"):
		// Trigger/procedure scripts need custom statement delimiters.
		if err := executeSQLWithDelimiter(db, sqlContent); err != nil {
			return err
		}
	default:
		if err := db.Exec(sqlContent).Error; err != nil {
			return fmt.Errorf("failed to execute SQL: %w", err)
		}
	}

	log.Printf("✅ SQL file executed: %s", filepath)
	return nil
}
// executeSQLWithDelimiter runs a SQL script that contains mysql-client style
// "DELIMITER <token>" directives, as used around trigger and stored-procedure
// definitions.
//
// The script is split on the keyword "DELIMITER":
//   - Text before the first directive is terminated by ";".
//   - Every later section takes its terminator from the first token of its
//     first line (for example "$$", "//" or ";"). That declaration line is
//     dropped before the section is split.
//
// Statements from ";" sections must succeed: the first failure is returned,
// and nothing after it runs. Statements from custom-delimiter sections are
// best-effort. Their failures are logged and skipped, because such objects
// may already exist or need features the server's MySQL version lacks.
func executeSQLWithDelimiter(db *gorm.DB, content string) error {
	for _, stmt := range splitDelimitedSQL(content) {
		err := db.Exec(stmt.sql).Error
		if err == nil {
			continue
		}
		if stmt.bestEffort {
			log.Printf("⚠️ Warning executing statement: %v", err)
			continue
		}
		return fmt.Errorf("failed to execute statement: %w", err)
	}
	return nil
}

// delimitedStatement is one executable statement produced by splitDelimitedSQL.
type delimitedStatement struct {
	sql        string // statement text with surrounding whitespace trimmed
	bestEffort bool   // true for statements from a non-";" delimiter section
}

// splitDelimitedSQL splits a script into statements according to its
// DELIMITER directives, following the rules described on
// executeSQLWithDelimiter.
func splitDelimitedSQL(content string) []delimitedStatement {
	var stmts []delimitedStatement
	for i, section := range strings.Split(content, "DELIMITER") {
		section = strings.TrimSpace(section)
		delim := ";"
		if i > 0 {
			// A section after a directive starts with the new delimiter
			// token, e.g. "$$" in "DELIMITER $$". The delimiter is read from
			// the script rather than assumed from the section's position.
			firstAndRest := strings.SplitN(section, "\n", 2)
			token := strings.Fields(firstAndRest[0])
			if len(token) == 0 {
				continue
			}
			delim = token[0]
			section = ""
			if len(firstAndRest) == 2 {
				section = firstAndRest[1]
			}
		}
		for _, s := range strings.Split(section, delim) {
			s = strings.TrimSpace(s)
			if s == "" || s == ";" {
				continue
			}
			stmts = append(stmts, delimitedStatement{sql: s, bestEffort: delim != ";"})
		}
	}
	return stmts
}
// removeComments strips whole-line SQL comments (lines starting with "--" or
// "#") and blank lines from sql. Every surviving line has its surrounding
// whitespace trimmed, and lines are rejoined with "\n". Comments that follow
// code on the same line are kept as-is.
func removeComments(sql string) string {
	rawLines := strings.Split(sql, "\n")
	kept := make([]string, 0, len(rawLines))
	for _, raw := range rawLines {
		trimmed := strings.TrimSpace(raw)
		switch {
		case trimmed == "":
			// blank line: drop
		case strings.HasPrefix(trimmed, "--"), strings.HasPrefix(trimmed, "#"):
			// whole-line comment: drop
		default:
			kept = append(kept, trimmed)
		}
	}
	return strings.Join(kept, "\n")
}
// CloseDB closes the connection pool behind the package-level handle. When no
// handle has been set, it does nothing and returns nil. An error from
// obtaining the underlying *sql.DB, or from closing it, is returned as is.
func CloseDB() error {
	if db == nil {
		return nil
	}
	sqlDB, err := db.DB()
	if err != nil {
		return err
	}
	return sqlDB.Close()
}
// getEnv returns the value of the environment variable key. It returns
// defaultValue when the variable is unset or set to the empty string.
func getEnv(key, defaultValue string) string {
	if v := os.Getenv(key); v != "" {
		return v
	}
	return defaultValue
}