From 7219420d8ae3ec9c312c6e0cf07ea3dd329e4090 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thorben=20G=C3=BCnther?= Date: Thu, 20 Oct 2022 16:06:58 +0200 Subject: [PATCH] database: Remove dependency on global variable Those should be generally avoided. Instead add the database connection to the Resolver struct. --- database/postgres.go | 39 ++++++++++++++++++++------------------- database/queue.go | 4 ++-- graph/resolver.go | 6 +++++- graph/schema.resolvers.go | 36 ++++++++++++++++++------------------ server.go | 10 ++++++---- 5 files changed, 51 insertions(+), 44 deletions(-) diff --git a/database/postgres.go b/database/postgres.go index 6b4965e..b5e4516 100644 --- a/database/postgres.go +++ b/database/postgres.go @@ -3,7 +3,6 @@ package database import ( "database/sql" "errors" - "log" "git.xenrox.net/~xenrox/10man-api/config" // Import driver @@ -11,45 +10,47 @@ import ( ) // DB is the postgres connection -var DB *sql.DB +type DB struct { + DB *sql.DB +} // Open opens a postgres database connection -func Open() error { - db, err := sql.Open("postgres", config.ConnectionString) +func Open() (*DB, error) { + database, err := sql.Open("postgres", config.ConnectionString) if err != nil { - return err + return nil, err } - DB = db + if err = database.Ping(); err != nil { + return nil, err + } - err = migrate() - if err != nil { - return err + db := &DB{DB: database} + + if err := db.migrate(); err != nil { + return nil, err } - return nil + return db, nil } // Close terminates the postgres connection -func Close() { - err := DB.Close() - if err != nil { - log.Println("Failed to close database.") - } +func (db *DB) Close() error { + return db.DB.Close() } // migrate initializes the database and handles migrations -func migrate() error { - _, err := DB.Exec(versionTable) +func (db *DB) migrate() error { + _, err := db.DB.Exec(versionTable) if err != nil { return err } var version int query := `SELECT id FROM "Version"` - err = DB.QueryRow(query).Scan(&version) + err = db.DB.QueryRow(query).Scan(&version) if errors.Is(err, sql.ErrNoRows) { - _, err = DB.Exec(initialSchema) + _, err = db.DB.Exec(initialSchema) if err != nil { return err } diff --git a/database/queue.go b/database/queue.go index ef69cf0..6070b7d 100644 --- a/database/queue.go +++ b/database/queue.go @@ -1,10 +1,10 @@ package database // PlayersInQueue returns the number of players in the queue -func PlayersInQueue() (int, error) { +func (db *DB) PlayersInQueue() (int, error) { var players int query := `SELECT COUNT(*) FILTER (WHERE queueing) FROM "User"` - err := DB.QueryRow(query).Scan(&players) + err := db.DB.QueryRow(query).Scan(&players) if err != nil { return -1, err } diff --git a/graph/resolver.go b/graph/resolver.go index a25c09c..3f25155 100644 --- a/graph/resolver.go +++ b/graph/resolver.go @@ -1,7 +1,11 @@ package graph +import "git.xenrox.net/~xenrox/10man-api/database" + // This file will not be regenerated automatically. // // It serves as dependency injection for your app, add any dependencies you require here. -type Resolver struct{} +type Resolver struct { + DB *database.DB +} diff --git a/graph/schema.resolvers.go b/graph/schema.resolvers.go index 1b1f155..109eb9a 100644 --- a/graph/schema.resolvers.go +++ b/graph/schema.resolvers.go @@ -61,7 +61,7 @@ func (r *mutationResolver) CreateUser(ctx context.Context, input model.NewUser) query := ` INSERT INTO "User" (steam_id, teamspeak_id, elo, admin, avatar, name) VALUES ($1, $2, $3, $4, $5, $6)` - _, err = database.DB.Exec(query, input.SteamID, input.TeamspeakID, elo.Elo, + _, err = r.DB.DB.Exec(query, input.SteamID, input.TeamspeakID, elo.Elo, isAdmin, input.Avatar, input.Name) if err != nil { return "", database.CheckErrorCode(err) @@ -72,14 +72,14 @@ func (r *mutationResolver) CreateUser(ctx context.Context, input model.NewUser) func (r *mutationResolver) DeleteUser(ctx context.Context, id int) (string, error) { query := `DELETE FROM "User" WHERE id = $1` - _, err := database.DB.Exec(query, id) + _, err := r.DB.DB.Exec(query, id) return "Deleted", err } func (r *mutationResolver) UpdateUser(ctx context.Context, id int, input model.UserInput) (*model.User, error) { var user model.User - tx, err := database.DB.Begin() + tx, err := r.DB.DB.Begin() if err != nil { return nil, err } @@ -148,7 +148,7 @@ func (r *mutationResolver) UpdateUser(ctx context.Context, id int, input model.U } func (r *mutationResolver) StartQueue(ctx context.Context, teamspeakID string) (int, error) { - players, err := database.PlayersInQueue() + players, err := r.DB.PlayersInQueue() if err != nil { return players, err } @@ -161,7 +161,7 @@ func (r *mutationResolver) StartQueue(ctx context.Context, teamspeakID string) ( SELECT id FROM "Match" WHERE status = 'ongoing'` - err := database.DB.QueryRow(query).Scan(&id) + err := r.DB.DB.QueryRow(query).Scan(&id) if err == nil { return players, errors.New("cannot queue, match still ongoing") } else if err != nil && !errors.Is(err, sql.ErrNoRows) { @@ -177,7 +177,7 @@ func (r *mutationResolver) StartQueue(ctx context.Context, teamspeakID string) ( UPDATE "User" SET queueing = true WHERE teamspeak_id = $1 AND NOT queueing` - result, err := database.DB.Exec(query, teamspeakID) + result, err := r.DB.DB.Exec(query, teamspeakID) if err != nil { return players, fmt.Errorf("QueueFail: %w", err) } @@ -194,7 +194,7 @@ func (r *mutationResolver) StartQueue(ctx context.Context, teamspeakID string) ( SELECT queueing FROM "User" WHERE teamspeak_id = $1` - err = database.DB.QueryRow(query, teamspeakID).Scan(&queueing) + err = r.DB.DB.QueryRow(query, teamspeakID).Scan(&queueing) if err != nil { return players, database.CheckErrorCode(err) } @@ -215,7 +215,7 @@ func (r *mutationResolver) CancelQueue(ctx context.Context, teamspeakID string) UPDATE "User" SET queueing = false WHERE teamspeak_id = $1 AND queueing` - result, err := database.DB.Exec(query, teamspeakID) + result, err := r.DB.DB.Exec(query, teamspeakID) if err != nil { return -1, err } @@ -229,7 +229,7 @@ func (r *mutationResolver) CancelQueue(ctx context.Context, teamspeakID string) return -1, errors.New("user not in queue") } - players, err := database.PlayersInQueue() + players, err := r.DB.PlayersInQueue() return players, err } @@ -239,7 +239,7 @@ func (r *mutationResolver) CreateMatch(ctx context.Context) (int, error) { FROM "User" WHERE queueing ` - rows, err := database.DB.Query(query) + rows, err := r.DB.DB.Query(query) if err != nil { return -1, err } @@ -261,7 +261,7 @@ func (r *mutationResolver) CreateMatch(ctx context.Context) (int, error) { return -1, errors.New("queue is not full") } - tx, err := database.DB.Begin() + tx, err := r.DB.DB.Begin() if err != nil { return -1, err } @@ -329,7 +329,7 @@ func (r *mutationResolver) CancelMatch(ctx context.Context) (string, error) { UPDATE "Match" SET status = 'cancelled' WHERE status = 'ongoing'` - result, err := database.DB.Exec(query) + result, err := r.DB.DB.Exec(query) if err != nil { return "CancelFail", err } @@ -353,7 +353,7 @@ func (r *mutationResolver) FinishMatch(ctx context.Context, winner string) (stri SELECT t1, t2, elo1, elo2 FROM "Match" WHERE status = 'ongoing'` - err := database.DB.QueryRow(query). + err := r.DB.DB.QueryRow(query). Scan(&team1.ID, &team2.ID, &team1.Elo, &team2.Elo) if err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -367,7 +367,7 @@ func (r *mutationResolver) FinishMatch(ctx context.Context, winner string) (stri return "FinishFail", err } - tx, err := database.DB.Begin() + tx, err := r.DB.DB.Begin() if err != nil { return "FinishFail", err } @@ -440,7 +440,7 @@ func (r *mutationResolver) VetoMap(ctx context.Context, mapArg *string) (*model. UPDATE "Match" SET map = $1 WHERE status = 'ongoing'` - _, err := database.DB.Exec(query, voteMaps[0]) + _, err := r.DB.DB.Exec(query, voteMaps[0]) if err != nil { return nil, fmt.Errorf("failed to update match map: %w", err) } @@ -464,7 +464,7 @@ func (r *queryResolver) UserBySteam(ctx context.Context, steamID string) (*model SELECT id, teamspeak_id, Elo, admin, avatar, name FROM "User" WHERE steam_id = $1` - err := database.DB.QueryRow(query, steamID).Scan(&user.ID, + err := r.DB.DB.QueryRow(query, steamID).Scan(&user.ID, &user.TeamspeakID, &user.Elo, &user.Admin, &user.Avatar, &user.Name) if err != nil { return nil, database.CheckErrorCode(err) @@ -481,7 +481,7 @@ func (r *queryResolver) UserByTs(ctx context.Context, teamspeakID string) (*mode SELECT id, steam_id, Elo, admin, avatar, name FROM "User" WHERE teamspeak_id = $1` - err := database.DB.QueryRow(query, teamspeakID).Scan(&user.ID, + err := r.DB.DB.QueryRow(query, teamspeakID).Scan(&user.ID, &user.SteamID, &user.Elo, &user.Admin, &user.Avatar, &user.Name) if err != nil { return nil, database.CheckErrorCode(err) @@ -496,7 +496,7 @@ func (r *queryResolver) GetTeams(ctx context.Context, id *int) (*model.Teams, er var err error var rows *sql.Rows - tx, err := database.DB.Begin() + tx, err := r.DB.DB.Begin() if err != nil { return nil, database.CheckErrorCode(err) } diff --git a/server.go b/server.go index f6fc8be..d4e5a5a 100644 --- a/server.go +++ b/server.go @@ -23,7 +23,7 @@ func main() { sigs := make(chan os.Signal, 1) signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) - err := database.Open() + db, err := database.Open() if err != nil { log.Fatal(err) } @@ -36,10 +36,12 @@ func main() { log.Fatal(http.ListenAndServe(":"+port, nil)) }() - shutdown(<-sigs) + shutdown(<-sigs, db) } -func shutdown(sig os.Signal) { - database.Close() +func shutdown(sig os.Signal, db *database.DB) { + if err := db.DB.Close(); err != nil { + log.Fatalf("shutdown: %v", err) + } log.Println("API was shut down") } -- 2.44.0