This commit is contained in:
@@ -1,12 +1,26 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var (
|
||||
testDBSetupOnce sync.Once
|
||||
testDBHost string
|
||||
testDBPort string
|
||||
testDBUser string
|
||||
testDBPassword string
|
||||
testDBName string
|
||||
)
|
||||
|
||||
// TestSetupDB initializes the test database using existing functions
|
||||
// It creates the database if it doesn't exist and runs migrations
|
||||
// Uses sync.Once to ensure it only runs once across all tests
|
||||
func TestSetupDB(t *testing.T) {
|
||||
host := os.Getenv("DB_HOST")
|
||||
port := os.Getenv("DB_PORT")
|
||||
@@ -18,14 +32,60 @@ func TestSetupDB(t *testing.T) {
|
||||
t.Skip("Test database environment variables not set")
|
||||
}
|
||||
|
||||
// Use existing function to create database if it doesn't exist and run migrations
|
||||
Migrate_db(host, port, user, password, dbname)
|
||||
InitDB(host, port, user, password, dbname)
|
||||
// Store for TestTearDownDB
|
||||
testDBHost = host
|
||||
testDBPort = port
|
||||
testDBUser = user
|
||||
testDBPassword = password
|
||||
testDBName = dbname
|
||||
|
||||
// Only run setup once
|
||||
testDBSetupOnce.Do(func() {
|
||||
// Create the database first (testuser is a superuser in the container)
|
||||
createTestDatabase(host, port, dbname, user, password)
|
||||
|
||||
// Now run migrations using the existing function
|
||||
Migrate_db(host, port, user, password, dbname)
|
||||
InitDB(host, port, user, password, dbname)
|
||||
})
|
||||
}
|
||||
|
||||
// createTestDatabase creates the test database
|
||||
// In the test container, POSTGRES_USER is created as a superuser
|
||||
func createTestDatabase(host, port, dbname, user, password string) {
|
||||
// Connect to the postgres database to create new database
|
||||
connStr := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=postgres sslmode=disable", host, port, user, password)
|
||||
db, err := sql.Open("postgres", connStr)
|
||||
if err != nil {
|
||||
log.Println("Warning: Could not connect to create test database:", err)
|
||||
return
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Check if database exists
|
||||
var dbExists int
|
||||
err = db.QueryRow("SELECT 1 FROM pg_database WHERE datname = $1", dbname).Scan(&dbExists)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
log.Println("Warning: Could not check if database exists:", err)
|
||||
return
|
||||
}
|
||||
|
||||
if dbExists == 0 {
|
||||
// Create database
|
||||
_, err = db.Exec("CREATE DATABASE " + dbname)
|
||||
if err != nil {
|
||||
log.Println("Warning: Could not create database:", err)
|
||||
return
|
||||
}
|
||||
log.Println("Created test database:", dbname)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTearDownDB closes the test database connection
|
||||
// Note: We don't actually close the pool between tests to avoid
|
||||
// "closed pool" errors when tests run sequentially
|
||||
func TestTearDownDB(t *testing.T) {
|
||||
CloseDb()
|
||||
// CloseDb() // Disabled to prevent pool closure between sequential tests
|
||||
}
|
||||
|
||||
// TestClearDatabase clears all data from the test database
|
||||
|
||||
Reference in New Issue
Block a user