Files
MusicServer/internal/db/test_helpers.go
T
Sansan c63202242b feat: Complete DI cleanup - migrate test helpers to Database struct
- Update internal/db/test_helpers.go to use Database struct instead of globals
- Update internal/server/test_helpers.go to use TestDatabase.Pool
- Add TODO comment to old Dbpool/Ctx globals in dbHelper.go
- Remove db.Testf() usage from production code (kept for deprecated /dbtest endpoint)

Generated by Mistral Vibe.
Co-Authored-By: Mistral Vibe <vibe@mistral.ai>
2026-06-01 20:06:47 +02:00

135 lines
3.7 KiB
Go

package db
import (
	"context"
	"database/sql"
	"fmt"
	"log"
	"os"
	"strings"
	"sync"
	"testing"
)
// testDBSetupOnce is a one-shot guard for test database setup.
var testDBSetupOnce sync.Once

// Connection settings most recently read from the environment by
// TestSetupDB (originally recorded for TestTearDownDB). Nothing in this
// file reads them back.
var (
	testDBHost     string
	testDBPort     string
	testDBUser     string
	testDBPassword string
	testDBName     string
)

// TestDatabase is the database instance shared by tests. It is assigned by
// TestSetupDB; helpers in this file treat nil as "not initialized".
var TestDatabase *Database
// testDBSetupMu serializes TestSetupDB so that concurrent (t.Parallel) tests
// neither race on the package-level connection settings nor initialize
// TestDatabase twice.
var testDBSetupMu sync.Mutex

// TestSetupDB makes TestDatabase ready for the calling test.
//
// Connection settings come from DB_HOST, DB_PORT, DB_USERNAME, DB_PASSWORD and
// DB_NAME; if any of them is empty the calling test is skipped. When
// TestDatabase is nil, it creates the database if needed (best effort, see
// createTestDatabase), opens it with NewDatabase and runs migrations; otherwise
// the existing instance is reused.
//
// Initialization is retried whenever TestDatabase is nil (for example after an
// earlier attempt failed or another helper cleared it) instead of being
// guarded by a sync.Once, so later tests never silently receive a nil or
// unmigrated TestDatabase. A failed setup fails the calling test via t.Fatalf,
// and TestDatabase is only published once migrations have succeeded.
func TestSetupDB(t *testing.T) {
	t.Helper()

	host := os.Getenv("DB_HOST")
	port := os.Getenv("DB_PORT")
	user := os.Getenv("DB_USERNAME")
	password := os.Getenv("DB_PASSWORD")
	dbname := os.Getenv("DB_NAME")
	if host == "" || port == "" || user == "" || password == "" || dbname == "" {
		t.Skip("Test database environment variables not set")
	}

	// t.Fatalf below exits via runtime.Goexit, which still runs this deferred
	// Unlock, so a failing test cannot wedge later callers.
	testDBSetupMu.Lock()
	defer testDBSetupMu.Unlock()

	// Record the active settings in the package-level testDB* variables.
	testDBHost = host
	testDBPort = port
	testDBUser = user
	testDBPassword = password
	testDBName = dbname

	if TestDatabase != nil {
		return
	}

	// Create the database first (testuser is a superuser in the container).
	createTestDatabase(host, port, dbname, user, password)

	tdb, err := NewDatabase(host, port, user, password, dbname)
	if err != nil {
		t.Fatalf("Failed to initialize test database: %v", err)
	}
	if err := tdb.RunMigrations(); err != nil {
		// Don't keep a half-initialized database around; the next caller
		// retries from scratch.
		tdb.Close()
		t.Fatalf("Failed to run migrations: %v", err)
	}
	TestDatabase = tdb
}
// createTestDatabase creates the database dbname if it does not exist yet.
//
// It connects to the server's "postgres" maintenance database with the given
// credentials (in the test container POSTGRES_USER is created as a superuser,
// so it may create databases). Every failure is logged as a warning and
// otherwise ignored: this step is best effort, and a real problem surfaces
// when the caller subsequently connects to dbname.
func createTestDatabase(host, port, dbname, user, password string) {
	connStr := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=postgres sslmode=disable", host, port, user, password)
	adminDB, err := sql.Open("postgres", connStr)
	if err != nil {
		log.Println("Warning: Could not connect to create test database:", err)
		return
	}
	defer adminDB.Close()

	// datname is compared exactly (case-sensitively), the same way the dbname
	// connection parameter is interpreted. EXISTS always yields one row, so
	// there is no sql.ErrNoRows case to special-case.
	var exists bool
	if err := adminDB.QueryRow("SELECT EXISTS (SELECT 1 FROM pg_database WHERE datname = $1)", dbname).Scan(&exists); err != nil {
		log.Println("Warning: Could not check if database exists:", err)
		return
	}
	if exists {
		return
	}

	// CREATE DATABASE cannot take a bind parameter, so the name is quoted as
	// an identifier. Quoting preserves its exact case (unquoted identifiers
	// are folded to lower case, which would create a database that neither
	// the check above nor the later connection would find), allows names
	// such as "music-test", and keeps DB_NAME from injecting SQL.
	if _, err := adminDB.Exec("CREATE DATABASE " + quotePostgresIdentifier(dbname)); err != nil {
		log.Println("Warning: Could not create database:", err)
		return
	}
	log.Println("Created test database:", dbname)
}

// quotePostgresIdentifier returns name as a double-quoted PostgreSQL
// identifier, doubling any embedded double quotes.
func quotePostgresIdentifier(name string) string {
	return `"` + strings.ReplaceAll(name, `"`, `""`) + `"`
}
// TestTearDownDB is the per-test counterpart of TestSetupDB.
//
// It deliberately leaves TestDatabase and its pool open. The database is
// shared by every test in the binary, and other helpers (per this change's
// history, the server test helpers) use TestDatabase.Pool; closing it here
// would cause "closed pool" errors in tests that run afterwards. The pool's
// connections are released when the test process exits.
func TestTearDownDB(t *testing.T) {
	t.Helper()
	// Intentionally a no-op; see the doc comment above.
}
// TestClearDatabase empties the song_list, song and game tables of the shared
// test database and restarts game_id_seq, so a test can start from a clean
// slate. The calling test is skipped when TestDatabase (or its pool) has not
// been initialized.
//
// Clearing is best effort: a failing TRUNCATE or sequence reset (for example
// because a table is missing from the current schema) is logged with t.Logf
// and does not fail the test.
func TestClearDatabase(t *testing.T) {
	t.Helper()

	if TestDatabase == nil || TestDatabase.Pool == nil {
		t.Skip("Database not initialized")
	}
	pool := TestDatabase.Pool
	ctx := context.Background()

	// TRUNCATE ... CASCADE also empties every table that references the named
	// one through a foreign key, so the order of this list is not load-bearing.
	for _, table := range []string{"song_list", "song", "game"} {
		if _, err := pool.Exec(ctx, "TRUNCATE TABLE "+table+" CASCADE"); err != nil {
			t.Logf("Failed to truncate table %s: %v", table, err)
		}
	}

	// With is_called = false the next nextval('game_id_seq') returns 1.
	if _, err := pool.Exec(ctx, "SELECT setval('game_id_seq', 1, false)"); err != nil {
		t.Logf("Failed to reset game_id_seq: %v", err)
	}
}