333 lines
8.2 KiB
Go
333 lines
8.2 KiB
Go
package config
|
|
|
|
import (
	"fmt"
	"log"
	"net/url"
	"os"
	"strings"
	"time"

	"gorm.io/driver/mysql"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"

	"lost-and-found/internal/models"
)
|
|
|
|
var (
	// db is the package-wide GORM handle: InitDB assigns it, GetDB hands it
	// out and CloseDB releases its connection pool.
	db *gorm.DB
)
|
|
|
|
// DatabaseConfig describes how to reach and open the application's MySQL
// database. All values are kept as strings because they are spliced into a
// driver DSN.
type DatabaseConfig struct {
	// Host and Port locate the MySQL server.
	Host, Port string

	// User and Password are the credentials used to authenticate.
	User, Password string

	// DBName is the schema the application uses.
	DBName string

	// Charset, ParseTime and Loc become the charset, parseTime and loc DSN
	// query parameters.
	Charset, ParseTime, Loc string
}
|
|
|
|
// GetDatabaseConfig returns database configuration from environment
|
|
func GetDatabaseConfig() DatabaseConfig {
|
|
return DatabaseConfig{
|
|
Host: getEnv("DB_HOST", "localhost"),
|
|
Port: getEnv("DB_PORT", "3306"),
|
|
User: getEnv("DB_USER", "root"),
|
|
Password: getEnv("DB_PASSWORD", ""),
|
|
DBName: getEnv("DB_NAME", "lost_and_found"),
|
|
Charset: getEnv("DB_CHARSET", "utf8mb4"),
|
|
ParseTime: getEnv("DB_PARSE_TIME", "True"),
|
|
Loc: getEnv("DB_LOC", "Local"),
|
|
}
|
|
}
|
|
|
|
// InitDB initializes database connection
|
|
func InitDB() error {
|
|
config := GetDatabaseConfig()
|
|
|
|
// Step 1: Connect to MySQL without specifying database (to create if not exists)
|
|
if err := ensureDatabaseExists(config); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Step 2: Connect to the specific database
|
|
dsn := fmt.Sprintf(
|
|
"%s:%s@tcp(%s:%s)/%s?charset=%s&parseTime=%s&loc=%s&multiStatements=true",
|
|
config.User,
|
|
config.Password,
|
|
config.Host,
|
|
config.Port,
|
|
config.DBName,
|
|
config.Charset,
|
|
config.ParseTime,
|
|
config.Loc,
|
|
)
|
|
|
|
// Configure GORM logger
|
|
gormLogger := logger.Default
|
|
if IsDevelopment() {
|
|
gormLogger = logger.Default.LogMode(logger.Info)
|
|
} else {
|
|
gormLogger = logger.Default.LogMode(logger.Error)
|
|
}
|
|
|
|
// Open database connection
|
|
var err error
|
|
db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{
|
|
Logger: gormLogger,
|
|
NowFunc: func() time.Time {
|
|
return time.Now().Local()
|
|
},
|
|
})
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("failed to connect to database: %w", err)
|
|
}
|
|
|
|
// Get underlying SQL database
|
|
sqlDB, err := db.DB()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get database instance: %w", err)
|
|
}
|
|
|
|
// Set connection pool settings
|
|
sqlDB.SetMaxIdleConns(10)
|
|
sqlDB.SetMaxOpenConns(100)
|
|
sqlDB.SetConnMaxLifetime(time.Hour)
|
|
|
|
// Test connection
|
|
if err := sqlDB.Ping(); err != nil {
|
|
return fmt.Errorf("failed to ping database: %w", err)
|
|
}
|
|
|
|
log.Println("✅ Database connected successfully")
|
|
|
|
return nil
|
|
}
|
|
|
|
// ensureDatabaseExists connects to the MySQL server without selecting a
// schema and creates config.DBName if it does not exist yet.
//
// A new schema always uses the utf8mb4 character set and utf8mb4_unicode_ci
// collation, independent of config.Charset. The temporary server connection
// is closed on every return path, including error paths.
func ensureDatabaseExists(config DatabaseConfig) error {
	// loc must be query-escaped for the driver's DSN parser (e.g.
	// "Asia/Jakarta"). Accept raw or pre-escaped values: normalize, then
	// escape once.
	loc := config.Loc
	if raw, err := url.PathUnescape(loc); err == nil {
		loc = raw
	}
	loc = url.QueryEscape(loc)

	// Connect to MySQL without specifying a database.
	dsn := fmt.Sprintf(
		"%s:%s@tcp(%s:%s)/?charset=%s&parseTime=%s&loc=%s",
		config.User,
		config.Password,
		config.Host,
		config.Port,
		config.Charset,
		config.ParseTime,
		loc,
	)

	tempDB, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
		Logger: logger.Default.LogMode(logger.Silent),
	})
	if err != nil {
		return fmt.Errorf("failed to connect to MySQL server: %w", err)
	}
	// Close the temporary connection on every exit, not just on success.
	defer func() {
		if sqlDB, dbErr := tempDB.DB(); dbErr == nil {
			_ = sqlDB.Close()
		}
	}()

	log.Printf("🔍 Checking if database '%s' exists...", config.DBName)

	// The schema name is bound as a parameter here, so no quoting is needed.
	var dbExists int64
	if err := tempDB.Raw(
		"SELECT COUNT(*) FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?",
		config.DBName,
	).Scan(&dbExists).Error; err != nil {
		return fmt.Errorf("failed to check database existence: %w", err)
	}

	if dbExists > 0 {
		log.Printf("✅ Database '%s' already exists", config.DBName)
		return nil
	}

	log.Printf("📝 Creating database '%s'...", config.DBName)
	// Identifiers cannot be bound as parameters. Backtick-quote the name and
	// double any embedded backtick. This supports names such as
	// "lost-and-found" and prevents SQL injection through DB_NAME.
	quotedName := "`" + strings.ReplaceAll(config.DBName, "`", "``") + "`"
	createSQL := fmt.Sprintf(
		"CREATE DATABASE IF NOT EXISTS %s CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci",
		quotedName,
	)
	if err := tempDB.Exec(createSQL).Error; err != nil {
		return fmt.Errorf("failed to create database: %w", err)
	}
	log.Printf("✅ Database '%s' created successfully", config.DBName)

	return nil
}
|
|
|
|
// GetDB returns the database instance
|
|
func GetDB() *gorm.DB {
|
|
return db
|
|
}
|
|
|
|
// RunMigrations bootstraps a fresh database from the SQL scripts in the
// database/ directory. Paths are relative, so they resolve against the
// process working directory.
//
// If the table for models.Role already exists, the database is treated as
// migrated and nothing runs. Otherwise schema.sql and seed.sql run in that
// order, and the first failure is returned. enhancement.sql (triggers,
// procedures, ...) is then attempted, and its failure only logs a warning.
//
// NOTE(review): the final log line prints default admin credentials; consider
// emitting it only in development so it never reaches production logs.
// NOTE(review): the "already migrated" check keys on a single table, so a run
// that failed after creating that table may never be retried.
func RunMigrations(db *gorm.DB) error {
	log.Println("📊 Starting database migrations...")

	if db.Migrator().HasTable(&models.Role{}) {
		log.Println("✅ Database tables already exist, skipping migration")
		return nil
	}
	log.Println("📋 Tables not found, running migration scripts...")

	// Mandatory scripts, applied in order; any failure aborts the migration.
	required := []struct {
		file    string
		success string
	}{
		{file: "schema.sql", success: "✅ Schema created successfully"},
		{file: "seed.sql", success: "✅ Seed data inserted successfully"},
	}
	for _, step := range required {
		if err := runSQLFile(db, "database/"+step.file); err != nil {
			return fmt.Errorf("❌ Failed to run %s: %w", step.file, err)
		}
		log.Println(step.success)
	}

	// Optional script: a failure degrades features but does not abort.
	if enhanceErr := runSQLFile(db, "database/enhancement.sql"); enhanceErr != nil {
		log.Printf("⚠️ Warning: Failed to run enhancement.sql: %v", enhanceErr)
		log.Println("💡 Enhancement features (triggers, procedures) may not be available")
	} else {
		log.Println("✅ Enhancement features loaded successfully")
	}

	log.Println("🎉 Database migration completed!")
	log.Println("🔧 Default admin: admin@lostandfound.com / password123")

	return nil
}
|
|
|
|
// runSQLFile loads the SQL script at filepath and executes it against db.
//
// Blank lines and whole-line comments are stripped by removeComments first.
// If the remaining text contains the exact (case-sensitive) word "DELIMITER",
// execution is delegated to executeSQLWithDelimiter. Otherwise the whole
// script goes to the server in one Exec call. That path presumably needs a
// connection with multiStatements=true, which InitDB's DSN enables.
//
// A missing file yields a "file not found" error; read and execution failures
// are wrapped with context.
func runSQLFile(db *gorm.DB, filepath string) error {
	if _, statErr := os.Stat(filepath); os.IsNotExist(statErr) {
		return fmt.Errorf("file not found: %s", filepath)
	}

	log.Printf("📄 Reading SQL file: %s", filepath)
	raw, readErr := os.ReadFile(filepath)
	if readErr != nil {
		return fmt.Errorf("failed to read file: %w", readErr)
	}

	script := removeComments(string(raw))

	// Trigger/procedure bodies contain ';', so such scripts must be split
	// client-side on their declared custom delimiter.
	if strings.Contains(script, "DELIMITER") {
		return executeSQLWithDelimiter(db, script)
	}

	if res := db.Exec(script); res.Error != nil {
		return fmt.Errorf("failed to execute SQL: %w", res.Error)
	}

	log.Printf("✅ SQL file executed: %s", filepath)
	return nil
}
|
|
|
|
// executeSQLWithDelimiter runs a SQL script that uses mysql-client style
// "DELIMITER <token>" directives. Such directives are typically used to
// define triggers and stored procedures whose bodies contain ';'.
//
// The script is split on the literal word "DELIMITER":
//   - Text before the first directive is terminated by the default ';'.
//   - Every later section starts with the token declared by its directive
//     (e.g. "$$" or "//"), followed by statements terminated by that token.
//
// Statements in ';' sections are required: the first failure is returned.
// Statements in custom-delimiter sections are best-effort: failures are logged
// and skipped, because those definitions may already exist or be unsupported
// by the server version.
//
// NOTE(review): splitting on ';' is not SQL-aware. A ';' inside a string
// literal in a ';' section will break that statement.
func executeSQLWithDelimiter(db *gorm.DB, content string) error {
	sections := strings.Split(content, "DELIMITER")

	for i, section := range sections {
		section = strings.TrimSpace(section)
		if section == "" {
			continue
		}

		// Read the delimiter from the directive line itself instead of
		// assuming sections alternate between ';' and a hard-coded "$$".
		delimiter, body := ";", section
		if i > 0 {
			declLine, rest, _ := strings.Cut(section, "\n")
			if fields := strings.Fields(declLine); len(fields) > 0 {
				delimiter = fields[0]
			}
			body = rest
		}

		required := delimiter == ";"

		for _, stmt := range strings.Split(body, delimiter) {
			stmt = strings.TrimSpace(stmt)
			if stmt == "" || stmt == ";" {
				continue
			}
			if err := db.Exec(stmt).Error; err != nil {
				if required {
					return fmt.Errorf("failed to execute statement: %w", err)
				}
				// Don't fail on enhancement errors (triggers, procedures):
				// they might already exist or hit MySQL version issues.
				log.Printf("⚠️ Warning executing statement: %v", err)
			}
		}
	}

	return nil
}
|
|
|
|
// removeComments drops every blank line and every whole-line comment (a line
// whose first non-whitespace characters are "--" or "#") from sql. Surviving
// lines are trimmed of surrounding whitespace and re-joined with "\n".
// Comments that follow code on the same line and /* ... */ block comments are
// left untouched.
func removeComments(sql string) string {
	var out strings.Builder

	for _, raw := range strings.Split(sql, "\n") {
		line := strings.TrimSpace(raw)
		isComment := strings.HasPrefix(line, "--") || strings.HasPrefix(line, "#")
		if line == "" || isComment {
			continue
		}

		// Kept lines are never empty, so a non-empty builder means a
		// previous line was written and needs a separator.
		if out.Len() > 0 {
			out.WriteByte('\n')
		}
		out.WriteString(line)
	}

	return out.String()
}
|
|
|
|
// CloseDB closes database connection
|
|
func CloseDB() error {
|
|
if db != nil {
|
|
sqlDB, err := db.DB()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return sqlDB.Close()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// getEnv returns the value of the environment variable key. It returns
// defaultValue when the variable is unset or set to the empty string.
func getEnv(key, defaultValue string) string {
	if v := os.Getenv(key); v != "" {
		return v
	}
	return defaultValue
}