diff --git a/database/database.go b/database/database.go index a852bb5..e419751 100644 --- a/database/database.go +++ b/database/database.go @@ -9,18 +9,6 @@ import ( _ "github.com/jinzhu/gorm/dialects/mysql" ) -// formatConnectStrings concatenates the data from the config file into a -// usable MySQL connection string. -func formatConnectString(c config.ConfigStruct) string { - return fmt.Sprintf("%v:%v@tcp(%v:%v)/%v?parseTime=true", - c.MySQLUser, - c.MySQLPass, - c.MySQLHost, - c.MySQLPort, - c.MySQLDB, - ) -} - // OpenConnection does as its name dictates and opens a connection to the // MysqlHost listed in the config func OpenConnection() (db *gorm.DB, err error) { @@ -36,6 +24,8 @@ func OpenConnection() (db *gorm.DB, err error) { // InitDB initializes database tables. func InitDB(db *gorm.DB) { + db = assertOpenConnection(db) + db.CreateTable(&White_Torrent{}) db.CreateTable(&Torrent{}) db.CreateTable(&TrackerStats{}) @@ -44,11 +34,8 @@ func InitDB(db *gorm.DB) { // AddWhitelistedTorrent adds a torrent to the whitelist so that they may be // used by the tracker in the future. -func (t *White_Torrent) AddWhitelistedTorrent() bool { - db, err := OpenConnection() - if err != nil { - err = err - } +func (t *White_Torrent) AddWhitelistedTorrent(db *gorm.DB) bool { + db = assertOpenConnection(db) db.Create(t) return db.NewRecord(t) @@ -57,11 +44,9 @@ func (t *White_Torrent) AddWhitelistedTorrent() bool { // GetTorrent retrieves a torrent by its infoHash from the generic torrent // table in the database. Note: there's also a whitelisted torrent table // (`white_torrent`). -func GetTorrent(infoHash string) (t *Torrent, err error) { - db, err := OpenConnection() - if err != nil { - err = err - } +func GetTorrent(infoHash string) (db *gorm.DB, t *Torrent, err error) { + db = assertOpenConnection(db) + t = &Torrent{} db.Where("info_hash = ?", infoHash).Find(&t) @@ -70,11 +55,9 @@ func GetTorrent(infoHash string) (t *Torrent, err error) { } // GetWhitelistedTorrent Retrieves a single whitelisted torrent by its infoHash -func GetWhitelistedTorrent(infoHash string) (t *White_Torrent, err error) { - db, err := OpenConnection() - if err != nil { - err = err - } +func GetWhitelistedTorrent(infoHash string) (db *gorm.DB, t *White_Torrent, err error) { + db = assertOpenConnection(db) + t = &White_Torrent{} x := db.Where("info_hash = ?", infoHash).First(&t) @@ -86,11 +69,8 @@ func GetWhitelistedTorrent(infoHash string) (t *White_Torrent, err error) { } // UpdateStats Handles updating statistics relevant to our tracker. -func UpdateStats(uploaded uint64, downloaded uint64) { - db, err := OpenConnection() - if err != nil { - err = err - } +func UpdateStats(db *gorm.DB, uploaded uint64, downloaded uint64) { + db = assertOpenConnection(db) ts := &TrackerStats{} db.First(&ts) @@ -104,11 +84,8 @@ func UpdateStats(uploaded uint64, downloaded uint64) { } // UpdateTorrentStats Handles updating statistics relevant to our tracker. -func UpdateTorrentStats(seederDelta int64, leecherDelta int64) { - db, err := OpenConnection() - if err != nil { - err = err - } +func UpdateTorrentStats(db *gorm.DB, seederDelta int64, leecherDelta int64) { + db = assertOpenConnection(db) t := &Torrent{} db.First(&t) @@ -123,11 +100,8 @@ func UpdateTorrentStats(seederDelta int64, leecherDelta int64) { // UpdatePeerStats handles updating peer info like hits per ip, downloaded // amount, uploaded amounts. -func UpdatePeerStats(uploaded uint64, downloaded uint64, ip string) { - db, err := OpenConnection() - if err != nil { - err = err - } +func UpdatePeerStats(db *gorm.DB, uploaded uint64, downloaded uint64, ip string) { + db = assertOpenConnection(db) ps := &Peer_Stats{Ip: ip} db.First(&ps) @@ -142,11 +116,8 @@ func UpdatePeerStats(uploaded uint64, downloaded uint64, ip string) { // GetWhitelistedTorrents allows us to retrieve all of the white listed // torrents. Mostly used for populating the Redis KV storage with all of our // whitelisted torrents. -func GetWhitelistedTorrents() (x *sql.Rows, err error) { - db, err := OpenConnection() - if err != nil { - err = err - } +func GetWhitelistedTorrents(db *gorm.DB) (x *sql.Rows, err error) { + db = assertOpenConnection(db) x, err = db.Table("white_torrents").Rows() if err != nil { @@ -158,6 +129,35 @@ func GetWhitelistedTorrents() (x *sql.Rows, err error) { // ScrapeTorrent supports the Scrape convention func ScrapeTorrent(db *gorm.DB, infoHash string) (torrent *Torrent) { + db = assertOpenConnection(db) + db.Where("info_hash = ?", infoHash).First(&torrent) return } + +// formatConnectStrings concatenates the data from the config file into a +// usable MySQL connection string. +func formatConnectString(c config.ConfigStruct) string { + return fmt.Sprintf("%v:%v@tcp(%v:%v)/%v?parseTime=true", + c.MySQLUser, + c.MySQLPass, + c.MySQLHost, + c.MySQLPort, + c.MySQLDB, + ) +} + +// assertOpenConnection handles asserting a connection passed into a sql +// function is open, not nil. If nil, we'll create a new connection. +func assertOpenConnection(db *gorm.DB) *gorm.DB { + var err error + + if db == nil { + db, err = OpenConnection() + if err != nil { + err = err + } + } + + return db +} diff --git a/reaper/reaper.go b/reaper/reaper.go index 132b980..e5c938c 100644 --- a/reaper/reaper.go +++ b/reaper/reaper.go @@ -83,7 +83,7 @@ func StartReapingScheduler(waitTime time.Duration) { addedBy := new(string) dateAdded := new(int64) - x, err := db.GetWhitelistedTorrents() + x, err := db.GetWhitelistedTorrents(nil) for x.Next() { x.Scan(infoHash, name, addedBy, dateAdded) r.CreateNewTorrentKey(nil, *infoHash) diff --git a/server/server.go b/server/server.go index 3a1a934..1384bb5 100644 --- a/server/server.go +++ b/server/server.go @@ -17,6 +17,7 @@ type applicationContext struct { config config.ConfigStruct trackerLevel int peerStoreClient peerStore.PeerStore + dbPool *gorm.DB } type scrapeData struct { @@ -64,21 +65,21 @@ func (app *applicationContext) worker(data *a.AnnounceData) []string { } func (app *applicationContext) handleStatsTracking(data *a.AnnounceData) { - db.UpdateStats(data.Uploaded, data.Downloaded) + db.UpdateStats(app.dbPool, data.Uploaded, data.Downloaded) if app.trackerLevel > a.RATIOLESS { - db.UpdatePeerStats(data.Uploaded, data.Downloaded, data.IP) + db.UpdatePeerStats(app.dbPool, data.Uploaded, data.Downloaded, data.IP) } if data.Event == "completed" { - db.UpdateTorrentStats(1, -1) + db.UpdateTorrentStats(app.dbPool, 1, -1) return } else if data.Left == 0 { // TODO(ian): Don't assume the peer is already in the DB - db.UpdateTorrentStats(1, -1) + db.UpdateTorrentStats(app.dbPool, 1, -1) return } else if data.Event == "started" { - db.UpdateTorrentStats(0, 1) + db.UpdateTorrentStats(app.dbPool, 0, 1) } } @@ -129,21 +130,19 @@ func (app *applicationContext) requestHandler(w http.ResponseWriter, req *http.R app.handleStatsTracking(data) } -func scrapeHandlerCurried(dbConn *gorm.DB) func(w http.ResponseWriter, req *http.Request) { - return func(w http.ResponseWriter, req *http.Request) { - query := req.URL.Query() +func (app *applicationContext) scrapeHandler(w http.ResponseWriter, req *http.Request) { + query := req.URL.Query() - infoHash := query.Get("InfoHash") - if infoHash == "" { - failMsg := fmt.Sprintf("Tracker does not support multiple entire DB scrapes.") - writeErrorResponse(w, failMsg) - } else { - torrentData := db.ScrapeTorrent(dbConn, infoHash) - writeResponse(w, formatScrapeResponse(torrentData)) - } - - return + infoHash := query.Get("InfoHash") + if infoHash == "" { + failMsg := fmt.Sprintf("Tracker does not support multiple entire DB scrapes.") + writeErrorResponse(w, failMsg) + } else { + torrentData := db.ScrapeTorrent(app.dbPool, infoHash) + writeResponse(w, formatScrapeResponse(torrentData)) } + + return } func writeErrorResponse(w http.ResponseWriter, failMsg string) { @@ -157,22 +156,21 @@ func writeResponse(w http.ResponseWriter, values string) { // RunServer spins up the server and muxes the routes. func RunServer() { + dbConn, err := db.OpenConnection() + if err != nil { + panic("Failed to open connection to remote database server.") + } + app := applicationContext{ config: config.LoadConfig(), trackerLevel: a.RATIOLESS, peerStoreClient: new(peerStore.RedisStore), + dbPool: dbConn, } mux := http.NewServeMux() - dbConn, err := db.OpenConnection() - if err != nil { - panic("Failed to open connection to remote database server.") - } - - scrapeHandler := scrapeHandlerCurried(dbConn) - mux.HandleFunc("/announce", app.requestHandler) - mux.HandleFunc("/scrape", scrapeHandler) + mux.HandleFunc("/scrape", app.scrapeHandler) http.ListenAndServe(":3000", mux) }