Skip to content

Commit

Permalink
Initial implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
uberswe committed May 2, 2020
1 parent 5fa551b commit e5192fe
Show file tree
Hide file tree
Showing 6 changed files with 562 additions and 0 deletions.
24 changes: 24 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# MultiQuery (mq)

A simple cli tool written in Go to query multiple MySQL databases.

Sometimes I have the need to query multiple MySQL databases with the same structure. So I built this simple tool to allow me to do just that. You can connect to a mysql host directly or via an SSH tunnel.

The following query will find all the databases with the prefix `wp_` and run a select query on each database and aggregate the results.
```bash
mq --host=localhost --prefix=wp_ --query="SELECT * FROM wp_users"
```

Use the help command to read about the other parameters that are supported

```bash
mq --help
```

This tool should have no bugs but use it at your own risk. If you are concerned please review my code and feel free to fork the repository to make your own changes. Feel free to open a pull request if you would like to contribute improvements to the code.

You can also use the `--threaded` option to run concurrent queries.

SSH tunneling is supported if you specify an SSH host.

This tool tries to read ssh config files and my.cnf files when possible.
60 changes: 60 additions & 0 deletions cmd/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package cmd

import (
"fmt"
"log"
"os"
"os/user"
"path"
)

var (
threaded bool
sshHost string
sshPort string
sshUser string
sshPass string
privkeyPath string
defaultKeyPath string
password []byte
dbUser string
dbPass string
dbHost string
dbPort string
dbName string
dbPrefix string
dbIgnore string
dbConf string
dbQuery string
)

func Execute() {

usr, err := user.Current()
if err != nil {
log.Fatal(err)
}

defaultKeyPath = path.Join(usr.HomeDir, ".ssh/id_rsa")

rootCmd.PersistentFlags().StringVarP(&dbQuery, "query", "q", "", "Mysql query you would like to run")
rootCmd.PersistentFlags().StringVarP(&dbUser, "user", "u", "", "Mysql user")
rootCmd.PersistentFlags().StringVarP(&dbPass, "password", "p", "", "Mysql password")
rootCmd.PersistentFlags().StringVar(&dbHost, "host", "", "Mysql host")
rootCmd.PersistentFlags().StringVar(&dbPort, "port", "3306", "Mysql port")
rootCmd.PersistentFlags().StringVarP(&dbName, "database", "d", "", "Mysql database")
rootCmd.PersistentFlags().StringVar(&dbPrefix, "dbprefix", "", "Mysql prefix for matching multiple database names")
rootCmd.PersistentFlags().StringVar(&dbIgnore, "dbignore", "", "Ignores any database containing this string")
rootCmd.PersistentFlags().StringVarP(&dbConf, "conf", "c", "~/.my.cnf", "Mysql config file location")
rootCmd.PersistentFlags().StringVar(&sshHost, "sshhost", "", "SSH host")
rootCmd.PersistentFlags().StringVar(&sshPort, "sshport", "", "SSH port to connect to (default is 22)")
rootCmd.PersistentFlags().StringVar(&sshUser, "sshuser", "root", "SSH user")
rootCmd.PersistentFlags().StringVar(&sshPass, "sshpass", "", "SSH password")
rootCmd.PersistentFlags().BoolVar(&threaded, "threaded", false, "Use threading to run queries in parallel")
rootCmd.PersistentFlags().StringVar(&privkeyPath, "sshkey", defaultKeyPath, "Path to your SSH private key")

if err := rootCmd.Execute(); err != nil {
fmt.Println(err)
os.Exit(1)
}
}
300 changes: 300 additions & 0 deletions cmd/root.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,300 @@
package cmd

import (
"bytes"
"context"
"crypto/x509"
"database/sql"
"encoding/pem"
"fmt"
"github.com/go-ini/ini"
"github.com/go-sql-driver/mysql"
"github.com/kevinburke/ssh_config"
"github.com/spf13/cobra"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"golang.org/x/crypto/ssh/terminal"
"io/ioutil"
"log"
"net"
"os"
"os/user"
"path"
"strings"
"syscall"
)

var rootCmd = &cobra.Command{
Use: "mq",
Short: "mq or MultiQuery is a cli tool to perform mysql queries on multiple databases",
Long: "mq or MultiQuery is a cli tool to perform mysql queries on multiple databases\n\nCreated by Markus Tenghamn ([email protected])",
Run: run,
}

type ViaSSHDialer struct {
client *ssh.Client
_ *context.Context
}

func (d *ViaSSHDialer) Dial(ctx context.Context, addr string) (net.Conn, error) {
return d.client.Dial("tcp", addr)
}

func run(cmd *cobra.Command, args []string) {
if sshHost != "" {
runOverSSH(runMysql, "mysql+tcp")
} else {
cfg, err := ini.Load(dbConf)
if err == nil {
loadMyCnf(cfg)
}
runMysql("tcp")
}
log.Println("done")
}

func runMysql(dbNet string) {
var databasesToQuery []string
if db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@%s(%s:%s)/%s?parseTime=true", dbUser, dbPass, dbNet, dbHost, dbPort, dbName)); err == nil {
defer db.Close()
if rows, err := db.Query("SHOW DATABASES"); err == nil {
for rows.Next() {
var databaseName string
_ = rows.Scan(&databaseName)
if dbPrefix != "" {
if strings.HasPrefix(databaseName, dbPrefix) && !strings.Contains(databaseName, dbIgnore) {
databasesToQuery = append(databasesToQuery, databaseName)
}
} else {
databasesToQuery = append(databasesToQuery, databaseName)
}

}
} else {
log.Fatal(err)
}
}

// Run the query
if dbQuery != "" {
queries := make(chan string)
var executedDbs []string
for _, databaseName := range databasesToQuery {
if threaded {
go executeThreadedQuery(dbNet, databaseName, queries)
} else {
executeQuery(dbNet, databaseName)
}
}
for {
msg := <-queries
executedDbs = append(executedDbs, msg)
if len(executedDbs) == len(databasesToQuery) {
break
}
}
}
}

func executeThreadedQuery(dbNet string, databaseName string, queries chan string) {
executeQuery(dbNet, databaseName)
queries <- databaseName
}

func executeQuery(dbNet string, databaseName string) {
if db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@%s(%s:%s)/%s?parseTime=true", dbUser, dbPass, dbNet, dbHost, dbPort, databaseName)); err == nil {
defer db.Close()
if rows, err := db.Query(dbQuery); err == nil {
cols, err := rows.Columns()
if err != nil {
fmt.Println("Failed to get columns", err)
return
}

// Result is your slice string.
rawResult := make([][]byte, len(cols))
result := make([]string, len(cols))

dest := make([]interface{}, len(cols)) // A temporary interface{} slice
for i, _ := range rawResult {
dest[i] = &rawResult[i] // Put pointers to each string in the interface slice
}

for rows.Next() {
err = rows.Scan(dest...)
if err != nil {
fmt.Println("Failed to scan row", err)
return
}

for i, raw := range rawResult {
if raw == nil {
result[i] = "\\N"
} else {
result[i] = string(raw)
}
}

fmt.Printf("%s: %#v\n", databaseName, result)
}
} else {
log.Fatal(err)
}
}

}

func runOverSSH(mysqlFunc func(dbNet string), dbNet string) {

// Attempt to read sshEnabled config file
sshHostName := ssh_config.Get(sshHost, "HostName")
sshPort := ssh_config.Get(sshHost, "Port")
sshIdentityFile := ssh_config.Get(sshHost, "IdentityFile")
sshUser := ssh_config.Get(sshHost, "User")

if sshHostName != "" {
sshHost = sshHostName
}
if sshPort == "" {
sshPort = sshPort
}
if sshIdentityFile != "~/.sshEnabled/identity" && privkeyPath == defaultKeyPath {
privkeyPath = sshIdentityFile
}
if sshUser != "root" && sshUser != "" {
sshUser = sshUser
}

if strings.Contains(privkeyPath, "~/") {
usr, err := user.Current()
if err != nil {
log.Fatal(err)
}

privkeyPath = path.Join(usr.HomeDir, strings.Replace(privkeyPath, "~", "", 1))
}

sshHost = fmt.Sprintf("%s:%s", sshHost, sshPort)

key, err := ioutil.ReadFile(privkeyPath)
if err != nil {
log.Fatalf("Unable to read private key: %v", err)
}

signer, err := ssh.ParsePrivateKey(key)
if err != nil {
CheckPassword()
der := decrypt(key, password)
key, err := x509.ParsePKCS1PrivateKey(der)
if err != nil {
log.Fatalf("Unable to parse private key: %v", err)
}
signer, err = ssh.NewSignerFromKey(key)
if err != nil {
log.Fatalf("Unable to get signer from private key: %v", err)
}

}

var agentClient agent.Agent
// Establish a connection to the local ssh-agent
if conn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil {
defer conn.Close()

// Create a new instance of the ssh agent
agentClient = agent.NewClient(conn)
}

// The client configuration with configuration option to use the ssh-agent
sshConfig := &ssh.ClientConfig{
User: sshUser,
Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
return nil
},
}

// When the agentClient connection succeeded, add them as AuthMethod
if agentClient != nil {
sshConfig.Auth = append(sshConfig.Auth, ssh.PublicKeysCallback(agentClient.Signers))
}
// When there's a non empty password add the password AuthMethod
if sshPass != "" {
sshConfig.Auth = append(sshConfig.Auth, ssh.PasswordCallback(func() (string, error) {
return sshPass, nil
}))
}

// Connect to the SSH Server
if sshcon, err := ssh.Dial("tcp", sshHost, sshConfig); err == nil {
defer sshcon.Close()

session, _ := sshcon.NewSession()
defer session.Close()

var stdoutBuf bytes.Buffer
session.Stdout = &stdoutBuf
session.Run(fmt.Sprintf("cat %s", dbConf))
mycnfBytes := stdoutBuf.Bytes()

if len(mycnfBytes) > 0 {
cfg, err := ini.Load(mycnfBytes)
if err != nil {
panic(err)
}
loadMyCnf(cfg)
}

dialer := ViaSSHDialer{client: sshcon}
// Now we register the ViaSSHDialer with the ssh connection as a parameter
mysql.RegisterDialContext("mysql+tcp", dialer.Dial)
// once we are connected to mysql over ssh we can run the mysql stuff as we normally would
mysqlFunc(dbNet)
} else {
log.Fatal(err)
}

}

func loadMyCnf(cfg *ini.File) {
for _, s := range cfg.Sections() {
if s.Key("host").String() != "" && s.Key("host").String() != dbHost {
dbHost = s.Key("host").String()
}
if s.Key("port").String() != "" && s.Key("port").String() != dbPort {
dbPort = s.Key("port").String()
}
if s.Key("dbname").String() != "" && s.Key("dbname").String() != dbName {
dbName = s.Key("dbname").String()
}
if s.Key("user").String() != "" && s.Key("user").String() != dbUser {
dbUser = s.Key("user").String()
}
if s.Key("password").String() != "" && s.Key("password").String() != dbPass {
dbPass = s.Key("password").String()
}
}
}

func CheckPassword() {
if password == nil {
var err error
fmt.Print("Enter ssh key Password: ")
password, err = terminal.ReadPassword(int(syscall.Stdin))
if err != nil {
log.Fatalf("Invalid password: %v", err)
}
}
}

func decrypt(key []byte, password []byte) []byte {
block, rest := pem.Decode(key)
if len(rest) > 0 {
log.Fatalf("Extra data included in key")
}
der, err := x509.DecryptPEMBlock(block, password)
if err != nil {
log.Fatalf("Decrypt failed: %v", err)
}
return der
}
Loading

0 comments on commit e5192fe

Please sign in to comment.