-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
562 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# MultiQuery (mq) | ||
|
||
A simple cli tool written in Go to query multiple MySQL databases. | ||
|
||
Sometimes I have the need to query multiple MySQL databases with the same structure. So I built this simple tool to allow me to do just that. You can connect to a mysql host directly or via an SSH tunnel. | ||
|
||
The following query will find all the databases with the prefix `wp_` and run a select query on each database and aggregate the results. | ||
```bash | ||
mq --host=localhost --prefix=wp_ --query="SELECT * FROM wp_users" | ||
``` | ||
|
||
Use the help command to read about the other parameters that are supported | ||
|
||
```bash | ||
mq --help | ||
``` | ||
|
||
This tool should have no bugs but use it at your own risk. If you are concerned please review my code and feel free to fork the repository to make your own changes. Feel free to open a pull request if you would like to contribute improvements to the code. | ||
|
||
You can also use the `--threaded` option to run concurrent queries. | ||
|
||
SSH tunneling is supported if you specify an SSH host. | ||
|
||
This tool tries to read ssh config files and my.cnf files when possible. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
package cmd | ||
|
||
import ( | ||
"fmt" | ||
"log" | ||
"os" | ||
"os/user" | ||
"path" | ||
) | ||
|
||
var ( | ||
threaded bool | ||
sshHost string | ||
sshPort string | ||
sshUser string | ||
sshPass string | ||
privkeyPath string | ||
defaultKeyPath string | ||
password []byte | ||
dbUser string | ||
dbPass string | ||
dbHost string | ||
dbPort string | ||
dbName string | ||
dbPrefix string | ||
dbIgnore string | ||
dbConf string | ||
dbQuery string | ||
) | ||
|
||
func Execute() { | ||
|
||
usr, err := user.Current() | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
|
||
defaultKeyPath = path.Join(usr.HomeDir, ".ssh/id_rsa") | ||
|
||
rootCmd.PersistentFlags().StringVarP(&dbQuery, "query", "q", "", "Mysql query you would like to run") | ||
rootCmd.PersistentFlags().StringVarP(&dbUser, "user", "u", "", "Mysql user") | ||
rootCmd.PersistentFlags().StringVarP(&dbPass, "password", "p", "", "Mysql password") | ||
rootCmd.PersistentFlags().StringVar(&dbHost, "host", "", "Mysql host") | ||
rootCmd.PersistentFlags().StringVar(&dbPort, "port", "3306", "Mysql port") | ||
rootCmd.PersistentFlags().StringVarP(&dbName, "database", "d", "", "Mysql database") | ||
rootCmd.PersistentFlags().StringVar(&dbPrefix, "dbprefix", "", "Mysql prefix for matching multiple database names") | ||
rootCmd.PersistentFlags().StringVar(&dbIgnore, "dbignore", "", "Ignores any database containing this string") | ||
rootCmd.PersistentFlags().StringVarP(&dbConf, "conf", "c", "~/.my.cnf", "Mysql config file location") | ||
rootCmd.PersistentFlags().StringVar(&sshHost, "sshhost", "", "SSH host") | ||
rootCmd.PersistentFlags().StringVar(&sshPort, "sshport", "", "SSH port to connect to (default is 22)") | ||
rootCmd.PersistentFlags().StringVar(&sshUser, "sshuser", "root", "SSH user") | ||
rootCmd.PersistentFlags().StringVar(&sshPass, "sshpass", "", "SSH password") | ||
rootCmd.PersistentFlags().BoolVar(&threaded, "threaded", false, "Use threading to run queries in parallel") | ||
rootCmd.PersistentFlags().StringVar(&privkeyPath, "sshkey", defaultKeyPath, "Path to your SSH private key") | ||
|
||
if err := rootCmd.Execute(); err != nil { | ||
fmt.Println(err) | ||
os.Exit(1) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,300 @@ | ||
package cmd | ||
|
||
import ( | ||
"bytes" | ||
"context" | ||
"crypto/x509" | ||
"database/sql" | ||
"encoding/pem" | ||
"fmt" | ||
"github.com/go-ini/ini" | ||
"github.com/go-sql-driver/mysql" | ||
"github.com/kevinburke/ssh_config" | ||
"github.com/spf13/cobra" | ||
"golang.org/x/crypto/ssh" | ||
"golang.org/x/crypto/ssh/agent" | ||
"golang.org/x/crypto/ssh/terminal" | ||
"io/ioutil" | ||
"log" | ||
"net" | ||
"os" | ||
"os/user" | ||
"path" | ||
"strings" | ||
"syscall" | ||
) | ||
|
||
var rootCmd = &cobra.Command{ | ||
Use: "mq", | ||
Short: "mq or MultiQuery is a cli tool to perform mysql queries on multiple databases", | ||
Long: "mq or MultiQuery is a cli tool to perform mysql queries on multiple databases\n\nCreated by Markus Tenghamn ([email protected])", | ||
Run: run, | ||
} | ||
|
||
type ViaSSHDialer struct { | ||
client *ssh.Client | ||
_ *context.Context | ||
} | ||
|
||
func (d *ViaSSHDialer) Dial(ctx context.Context, addr string) (net.Conn, error) { | ||
return d.client.Dial("tcp", addr) | ||
} | ||
|
||
func run(cmd *cobra.Command, args []string) { | ||
if sshHost != "" { | ||
runOverSSH(runMysql, "mysql+tcp") | ||
} else { | ||
cfg, err := ini.Load(dbConf) | ||
if err == nil { | ||
loadMyCnf(cfg) | ||
} | ||
runMysql("tcp") | ||
} | ||
log.Println("done") | ||
} | ||
|
||
func runMysql(dbNet string) { | ||
var databasesToQuery []string | ||
if db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@%s(%s:%s)/%s?parseTime=true", dbUser, dbPass, dbNet, dbHost, dbPort, dbName)); err == nil { | ||
defer db.Close() | ||
if rows, err := db.Query("SHOW DATABASES"); err == nil { | ||
for rows.Next() { | ||
var databaseName string | ||
_ = rows.Scan(&databaseName) | ||
if dbPrefix != "" { | ||
if strings.HasPrefix(databaseName, dbPrefix) && !strings.Contains(databaseName, dbIgnore) { | ||
databasesToQuery = append(databasesToQuery, databaseName) | ||
} | ||
} else { | ||
databasesToQuery = append(databasesToQuery, databaseName) | ||
} | ||
|
||
} | ||
} else { | ||
log.Fatal(err) | ||
} | ||
} | ||
|
||
// Run the query | ||
if dbQuery != "" { | ||
queries := make(chan string) | ||
var executedDbs []string | ||
for _, databaseName := range databasesToQuery { | ||
if threaded { | ||
go executeThreadedQuery(dbNet, databaseName, queries) | ||
} else { | ||
executeQuery(dbNet, databaseName) | ||
} | ||
} | ||
for { | ||
msg := <-queries | ||
executedDbs = append(executedDbs, msg) | ||
if len(executedDbs) == len(databasesToQuery) { | ||
break | ||
} | ||
} | ||
} | ||
} | ||
|
||
func executeThreadedQuery(dbNet string, databaseName string, queries chan string) { | ||
executeQuery(dbNet, databaseName) | ||
queries <- databaseName | ||
} | ||
|
||
func executeQuery(dbNet string, databaseName string) { | ||
if db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@%s(%s:%s)/%s?parseTime=true", dbUser, dbPass, dbNet, dbHost, dbPort, databaseName)); err == nil { | ||
defer db.Close() | ||
if rows, err := db.Query(dbQuery); err == nil { | ||
cols, err := rows.Columns() | ||
if err != nil { | ||
fmt.Println("Failed to get columns", err) | ||
return | ||
} | ||
|
||
// Result is your slice string. | ||
rawResult := make([][]byte, len(cols)) | ||
result := make([]string, len(cols)) | ||
|
||
dest := make([]interface{}, len(cols)) // A temporary interface{} slice | ||
for i, _ := range rawResult { | ||
dest[i] = &rawResult[i] // Put pointers to each string in the interface slice | ||
} | ||
|
||
for rows.Next() { | ||
err = rows.Scan(dest...) | ||
if err != nil { | ||
fmt.Println("Failed to scan row", err) | ||
return | ||
} | ||
|
||
for i, raw := range rawResult { | ||
if raw == nil { | ||
result[i] = "\\N" | ||
} else { | ||
result[i] = string(raw) | ||
} | ||
} | ||
|
||
fmt.Printf("%s: %#v\n", databaseName, result) | ||
} | ||
} else { | ||
log.Fatal(err) | ||
} | ||
} | ||
|
||
} | ||
|
||
func runOverSSH(mysqlFunc func(dbNet string), dbNet string) { | ||
|
||
// Attempt to read sshEnabled config file | ||
sshHostName := ssh_config.Get(sshHost, "HostName") | ||
sshPort := ssh_config.Get(sshHost, "Port") | ||
sshIdentityFile := ssh_config.Get(sshHost, "IdentityFile") | ||
sshUser := ssh_config.Get(sshHost, "User") | ||
|
||
if sshHostName != "" { | ||
sshHost = sshHostName | ||
} | ||
if sshPort == "" { | ||
sshPort = sshPort | ||
} | ||
if sshIdentityFile != "~/.sshEnabled/identity" && privkeyPath == defaultKeyPath { | ||
privkeyPath = sshIdentityFile | ||
} | ||
if sshUser != "root" && sshUser != "" { | ||
sshUser = sshUser | ||
} | ||
|
||
if strings.Contains(privkeyPath, "~/") { | ||
usr, err := user.Current() | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
|
||
privkeyPath = path.Join(usr.HomeDir, strings.Replace(privkeyPath, "~", "", 1)) | ||
} | ||
|
||
sshHost = fmt.Sprintf("%s:%s", sshHost, sshPort) | ||
|
||
key, err := ioutil.ReadFile(privkeyPath) | ||
if err != nil { | ||
log.Fatalf("Unable to read private key: %v", err) | ||
} | ||
|
||
signer, err := ssh.ParsePrivateKey(key) | ||
if err != nil { | ||
CheckPassword() | ||
der := decrypt(key, password) | ||
key, err := x509.ParsePKCS1PrivateKey(der) | ||
if err != nil { | ||
log.Fatalf("Unable to parse private key: %v", err) | ||
} | ||
signer, err = ssh.NewSignerFromKey(key) | ||
if err != nil { | ||
log.Fatalf("Unable to get signer from private key: %v", err) | ||
} | ||
|
||
} | ||
|
||
var agentClient agent.Agent | ||
// Establish a connection to the local ssh-agent | ||
if conn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil { | ||
defer conn.Close() | ||
|
||
// Create a new instance of the ssh agent | ||
agentClient = agent.NewClient(conn) | ||
} | ||
|
||
// The client configuration with configuration option to use the ssh-agent | ||
sshConfig := &ssh.ClientConfig{ | ||
User: sshUser, | ||
Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)}, | ||
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { | ||
return nil | ||
}, | ||
} | ||
|
||
// When the agentClient connection succeeded, add them as AuthMethod | ||
if agentClient != nil { | ||
sshConfig.Auth = append(sshConfig.Auth, ssh.PublicKeysCallback(agentClient.Signers)) | ||
} | ||
// When there's a non empty password add the password AuthMethod | ||
if sshPass != "" { | ||
sshConfig.Auth = append(sshConfig.Auth, ssh.PasswordCallback(func() (string, error) { | ||
return sshPass, nil | ||
})) | ||
} | ||
|
||
// Connect to the SSH Server | ||
if sshcon, err := ssh.Dial("tcp", sshHost, sshConfig); err == nil { | ||
defer sshcon.Close() | ||
|
||
session, _ := sshcon.NewSession() | ||
defer session.Close() | ||
|
||
var stdoutBuf bytes.Buffer | ||
session.Stdout = &stdoutBuf | ||
session.Run(fmt.Sprintf("cat %s", dbConf)) | ||
mycnfBytes := stdoutBuf.Bytes() | ||
|
||
if len(mycnfBytes) > 0 { | ||
cfg, err := ini.Load(mycnfBytes) | ||
if err != nil { | ||
panic(err) | ||
} | ||
loadMyCnf(cfg) | ||
} | ||
|
||
dialer := ViaSSHDialer{client: sshcon} | ||
// Now we register the ViaSSHDialer with the ssh connection as a parameter | ||
mysql.RegisterDialContext("mysql+tcp", dialer.Dial) | ||
// once we are connected to mysql over ssh we can run the mysql stuff as we normally would | ||
mysqlFunc(dbNet) | ||
} else { | ||
log.Fatal(err) | ||
} | ||
|
||
} | ||
|
||
func loadMyCnf(cfg *ini.File) { | ||
for _, s := range cfg.Sections() { | ||
if s.Key("host").String() != "" && s.Key("host").String() != dbHost { | ||
dbHost = s.Key("host").String() | ||
} | ||
if s.Key("port").String() != "" && s.Key("port").String() != dbPort { | ||
dbPort = s.Key("port").String() | ||
} | ||
if s.Key("dbname").String() != "" && s.Key("dbname").String() != dbName { | ||
dbName = s.Key("dbname").String() | ||
} | ||
if s.Key("user").String() != "" && s.Key("user").String() != dbUser { | ||
dbUser = s.Key("user").String() | ||
} | ||
if s.Key("password").String() != "" && s.Key("password").String() != dbPass { | ||
dbPass = s.Key("password").String() | ||
} | ||
} | ||
} | ||
|
||
func CheckPassword() { | ||
if password == nil { | ||
var err error | ||
fmt.Print("Enter ssh key Password: ") | ||
password, err = terminal.ReadPassword(int(syscall.Stdin)) | ||
if err != nil { | ||
log.Fatalf("Invalid password: %v", err) | ||
} | ||
} | ||
} | ||
|
||
func decrypt(key []byte, password []byte) []byte { | ||
block, rest := pem.Decode(key) | ||
if len(rest) > 0 { | ||
log.Fatalf("Extra data included in key") | ||
} | ||
der, err := x509.DecryptPEMBlock(block, password) | ||
if err != nil { | ||
log.Fatalf("Decrypt failed: %v", err) | ||
} | ||
return der | ||
} |
Oops, something went wrong.