diff options
Diffstat (limited to 'cmd')
-rw-r--r-- | cmd/puppeth/ssh.go | 49 | ||||
-rw-r--r-- | cmd/puppeth/wizard.go | 14 | ||||
-rw-r--r-- | cmd/puppeth/wizard_intro.go | 9 | ||||
-rw-r--r-- | cmd/puppeth/wizard_netstats.go | 4 | ||||
-rw-r--r-- | cmd/puppeth/wizard_network.go | 18 |
5 files changed, 75 insertions, 19 deletions
diff --git a/cmd/puppeth/ssh.go b/cmd/puppeth/ssh.go index 34fbc566d..93668945c 100644 --- a/cmd/puppeth/ssh.go +++ b/cmd/puppeth/ssh.go @@ -17,6 +17,8 @@ package main import ( + "bufio" + "bytes" "errors" "fmt" "io/ioutil" @@ -37,18 +39,26 @@ import ( type sshClient struct { server string // Server name or IP without port number address string // IP address of the remote server + pubkey []byte // RSA public key to authenticate the server client *ssh.Client logger log.Logger } // dial establishes an SSH connection to a remote node using the current user and -// the user's configured private RSA key. -func dial(server string) (*sshClient, error) { +// the user's configured private RSA key. If that fails, password authentication +// is fallen back to. The caller may override the login user via user@server:port. +func dial(server string, pubkey []byte) (*sshClient, error) { // Figure out a label for the server and a logger label := server if strings.Contains(label, ":") { label = label[:strings.Index(label, ":")] } + login := "" + if strings.Contains(server, "@") { + login = label[:strings.Index(label, "@")] + label = label[strings.Index(label, "@")+1:] + server = server[strings.Index(server, "@")+1:] + } logger := log.New("server", label) logger.Debug("Attempting to establish SSH connection") @@ -56,6 +66,9 @@ func dial(server string) (*sshClient, error) { if err != nil { return nil, err } + if login == "" { + login = user.Username + } // Configure the supported authentication methods (private key and password) var auths []ssh.AuthMethod @@ -71,7 +84,7 @@ func dial(server string) (*sshClient, error) { } } auths = append(auths, ssh.PasswordCallback(func() (string, error) { - fmt.Printf("What's the login password for %s at %s? (won't be echoed)\n> ", user.Username, server) + fmt.Printf("What's the login password for %s at %s? (won't be echoed)\n> ", login, server) blob, err := terminal.ReadPassword(int(syscall.Stdin)) fmt.Println() @@ -86,11 +99,36 @@ func dial(server string) (*sshClient, error) { return nil, errors.New("no IPs associated with domain") } // Try to dial in to the remote server - logger.Trace("Dialing remote SSH server", "user", user.Username, "key", path) + logger.Trace("Dialing remote SSH server", "user", login) if !strings.Contains(server, ":") { server += ":22" } - client, err := ssh.Dial("tcp", server, &ssh.ClientConfig{User: user.Username, Auth: auths}) + keycheck := func(hostname string, remote net.Addr, key ssh.PublicKey) error { + // If no public key is known for SSH, ask the user to confirm + if pubkey == nil { + fmt.Printf("The authenticity of host '%s (%s)' can't be established.\n", hostname, remote) + fmt.Printf("SSH key fingerprint is %s [MD5]\n", ssh.FingerprintLegacyMD5(key)) + fmt.Printf("Are you sure you want to continue connecting (yes/no)? ") + + text, err := bufio.NewReader(os.Stdin).ReadString('\n') + switch { + case err != nil: + return err + case strings.TrimSpace(text) == "yes": + pubkey = key.Marshal() + return nil + default: + return fmt.Errorf("unknown auth choice: %v", text) + } + } + // If a public key exists for this SSH server, check that it matches + if bytes.Compare(pubkey, key.Marshal()) == 0 { + return nil + } + // We have a mismatch, forbid connecting + return errors.New("ssh key mismatch, readd the machine to update") + } + client, err := ssh.Dial("tcp", server, &ssh.ClientConfig{User: login, Auth: auths, HostKeyCallback: keycheck}) if err != nil { return nil, err } @@ -98,6 +136,7 @@ func dial(server string) (*sshClient, error) { c := &sshClient{ server: label, address: addr[0], + pubkey: pubkey, client: client, logger: logger, } diff --git a/cmd/puppeth/wizard.go b/cmd/puppeth/wizard.go index 92d7962a0..9687d5e0d 100644 --- a/cmd/puppeth/wizard.go +++ b/cmd/puppeth/wizard.go @@ -44,14 +44,24 @@ type config struct { bootLight []string // Bootnodes to always connect to by light nodes ethstats string // Ethstats settings to cache for node deploys - Servers []string `json:"servers,omitempty"` + Servers map[string][]byte `json:"servers,omitempty"` +} + +// servers retrieves an alphabetically sorted list of servers. +func (c config) servers() []string { + servers := make([]string, 0, len(c.Servers)) + for server := range c.Servers { + servers = append(servers, server) + } + sort.Strings(servers) + + return servers } // flush dumps the contents of config to disk. func (c config) flush() { os.MkdirAll(filepath.Dir(c.path), 0755) - sort.Strings(c.Servers) out, _ := json.MarshalIndent(c, "", " ") if err := ioutil.WriteFile(c.path, out, 0644); err != nil { log.Warn("Failed to save puppeth configs", "file", c.path, "err", err) diff --git a/cmd/puppeth/wizard_intro.go b/cmd/puppeth/wizard_intro.go index 46383bb54..c3eaf5324 100644 --- a/cmd/puppeth/wizard_intro.go +++ b/cmd/puppeth/wizard_intro.go @@ -31,7 +31,10 @@ import ( // makeWizard creates and returns a new puppeth wizard. func makeWizard(network string) *wizard { return &wizard{ - network: network, + network: network, + conf: config{ + Servers: make(map[string][]byte), + }, servers: make(map[string]*sshClient), services: make(map[string][]string), in: bufio.NewReader(os.Stdin), @@ -77,9 +80,9 @@ func (w *wizard) run() { } else if err := json.Unmarshal(blob, &w.conf); err != nil { log.Crit("Previous configuration corrupted", "path", w.conf.path, "err", err) } else { - for _, server := range w.conf.Servers { + for server, pubkey := range w.conf.Servers { log.Info("Dialing previously configured server", "server", server) - client, err := dial(server) + client, err := dial(server, pubkey) if err != nil { log.Error("Previous server unreachable", "server", server, "err", err) } diff --git a/cmd/puppeth/wizard_netstats.go b/cmd/puppeth/wizard_netstats.go index c2a933a55..1225abb75 100644 --- a/cmd/puppeth/wizard_netstats.go +++ b/cmd/puppeth/wizard_netstats.go @@ -41,14 +41,14 @@ func (w *wizard) networkStats(tips bool) { stats.SetHeader([]string{"Server", "IP", "Status", "Service", "Details"}) stats.SetColWidth(128) - for _, server := range w.conf.Servers { + for server, pubkey := range w.conf.Servers { client := w.servers[server] logger := log.New("server", server) logger.Info("Starting remote server health-check") // If the server is not connected, try to connect again if client == nil { - conn, err := dial(server) + conn, err := dial(server, pubkey) if err != nil { logger.Error("Failed to establish remote connection", "err", err) stats.Append([]string{server, "", err.Error(), "", ""}) diff --git a/cmd/puppeth/wizard_network.go b/cmd/puppeth/wizard_network.go index 001d4e5b4..0455e1ef3 100644 --- a/cmd/puppeth/wizard_network.go +++ b/cmd/puppeth/wizard_network.go @@ -28,7 +28,9 @@ import ( func (w *wizard) manageServers() { // List all the servers we can disconnect, along with an entry to connect a new one fmt.Println() - for i, server := range w.conf.Servers { + + servers := w.conf.servers() + for i, server := range servers { fmt.Printf(" %d. Disconnect %s\n", i+1, server) } fmt.Printf(" %d. Connect another server\n", len(w.conf.Servers)+1) @@ -40,14 +42,14 @@ func (w *wizard) manageServers() { } // If the user selected an existing server, drop it if choice <= len(w.conf.Servers) { - server := w.conf.Servers[choice-1] + server := servers[choice-1] client := w.servers[server] delete(w.servers, server) if client != nil { client.Close() } - w.conf.Servers = append(w.conf.Servers[:choice-1], w.conf.Servers[choice:]...) + delete(w.conf.Servers, server) w.conf.flush() log.Info("Disconnected existing server", "server", server) @@ -73,14 +75,14 @@ func (w *wizard) makeServer() string { // Read and fial the server to ensure docker is present input := w.readString() - client, err := dial(input) + client, err := dial(input, nil) if err != nil { log.Error("Server not ready for puppeth", "err", err) return "" } // All checks passed, start tracking the server w.servers[input] = client - w.conf.Servers = append(w.conf.Servers, input) + w.conf.Servers[input] = client.pubkey w.conf.flush() return input @@ -93,7 +95,9 @@ func (w *wizard) selectServer() string { // List the available server to the user and wait for a choice fmt.Println() fmt.Println("Which server do you want to interact with?") - for i, server := range w.conf.Servers { + + servers := w.conf.servers() + for i, server := range servers { fmt.Printf(" %d. %s\n", i+1, server) } fmt.Printf(" %d. Connect another server\n", len(w.conf.Servers)+1) @@ -105,7 +109,7 @@ func (w *wizard) selectServer() string { } // If the user requested connecting to a new server, go for it if choice <= len(w.conf.Servers) { - return w.conf.Servers[choice-1] + return servers[choice-1] } return w.makeServer() } |