Files
cloudpods/pkg/util/ssh/ssh.go
wanyaoqi d4046b6c36 Automated cherry pick of #23725: feat(baremetal-agent): Linux PCIE disk support mdadm soft raid (#23930)
* feat(baremetal-agent): Linux PCIE disk support mdadm soft raid

* fix(baremeta-agent): soft raid support intel imsm

* fix: update baremetal-agent base
2025-12-22 13:29:36 +08:00

334 lines
7.4 KiB
Go

// Copyright 2019 Yunion
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ssh
import (
"bytes"
"context"
"fmt"
"io"
"net"
"os"
"strconv"
"strings"
"time"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/terminal"
"yunion.io/x/log"
"yunion.io/x/pkg/errors"
)
// Sentinel errors wrapped by the connection helpers in this file so callers
// can classify failures.
const (
	// ErrBadConfig is wrapped when a ClientConfig cannot be converted into an
	// ssh.ClientConfig (ToSshConfig wraps it when the private key fails to parse).
	ErrBadConfig = errors.Error("bad config")
	// ErrNetwork is wrapped by ConnectContext when the TCP dial fails.
	ErrNetwork = errors.Error("network error")
	// ErrProtocol is wrapped by ConnectContext around SSH handshake
	// (ssh.NewClientConn) failures.
	ErrProtocol = errors.Error("ssh protocol error")
)
// ClientConfig holds the parameters used to open an SSH connection to a
// single host. At least one of Password or PrivateKey is expected to be set;
// ToSshConfig offers whichever of them is non-empty as an auth method.
type ClientConfig struct {
	Username   string // remote login name (ssh.ClientConfig.User)
	Password   string // enables password auth when non-empty
	Host       string // combined with Port via net.JoinHostPort
	Port       int    // TCP port of the SSH server
	PrivateKey string // key text for ssh.ParsePrivateKey; enables public-key auth when non-empty
}
// parsePrivateKey turns the textual private key in keyBuff into an
// ssh.Signer via ssh.ParsePrivateKey. No passphrase is supplied, and any
// parse error is returned unchanged.
func parsePrivateKey(keyBuff string) (ssh.Signer, error) {
	raw := []byte(keyBuff)
	return ssh.ParsePrivateKey(raw)
}
// ToSshConfig converts conf into an *ssh.ClientConfig.
//
// Password authentication is offered first when Password is non-empty,
// followed by public-key authentication when PrivateKey is non-empty. The
// returned config does not verify host keys and sets a 15-second Timeout.
// A PrivateKey that fails to parse yields an error wrapping ErrBadConfig.
func (conf ClientConfig) ToSshConfig() (*ssh.ClientConfig, error) {
	methods := make([]ssh.AuthMethod, 0)
	if len(conf.Password) > 0 {
		methods = append(methods, ssh.Password(conf.Password))
	}
	if len(conf.PrivateKey) > 0 {
		signer, err := parsePrivateKey(conf.PrivateKey)
		if err != nil {
			return nil, errors.Wrapf(ErrBadConfig, "parse private key: %v", err)
		}
		methods = append(methods, ssh.PublicKeys(signer))
	}
	return &ssh.ClientConfig{
		User: conf.Username,
		Auth: methods,
		// SECURITY NOTE(review): any host key is accepted, so connections are
		// open to man-in-the-middle attacks; consider a known_hosts callback.
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         15 * time.Second,
	}, nil
}
// Connect builds the SSH client configuration from conf and dials
// conf.Host:conf.Port over TCP with ssh.Dial. Configuration errors come from
// ToSshConfig; dial/handshake errors from ssh.Dial are returned unwrapped.
func (conf ClientConfig) Connect() (*ssh.Client, error) {
	sshConfig, err := conf.ToSshConfig()
	if err != nil {
		return nil, err
	}
	target := net.JoinHostPort(conf.Host, strconv.Itoa(conf.Port))
	return ssh.Dial("tcp", target, sshConfig)
}
// ConnectContext is like Connect but honors ctx for the whole connection
// setup: the TCP dial uses DialContext, and if ctx is cancelled or its
// deadline passes while the SSH handshake is still running, the handshake is
// aborted instead of blocking indefinitely.
//
// The TCP dial is also bounded by the config's Timeout, matching Connect
// (ssh.Dial). Errors:
//   - configuration errors from ToSshConfig are returned as-is;
//   - TCP dial failures and ctx ending during the handshake wrap ErrNetwork;
//   - other handshake failures wrap ErrProtocol.
func (conf ClientConfig) ConnectContext(ctx context.Context) (*ssh.Client, error) {
	cliConfig, err := conf.ToSshConfig()
	if err != nil {
		return nil, err
	}
	addr := net.JoinHostPort(conf.Host, strconv.Itoa(conf.Port))
	// Use the same TCP connect timeout Connect gets from ssh.Dial; an earlier
	// ctx deadline still takes precedence.
	d := &net.Dialer{Timeout: cliConfig.Timeout}
	netconn, err := d.DialContext(ctx, "tcp", addr)
	if err != nil {
		return nil, errors.Wrapf(ErrNetwork, "tcp dial: %v", err)
	}

	// ssh.NewClientConn takes no context, so watch ctx ourselves and force
	// the handshake's blocked I/O to fail if ctx ends first.
	handshakeDone := make(chan struct{})
	watcherExited := make(chan struct{})
	aborted := false
	go func() {
		defer close(watcherExited)
		select {
		case <-ctx.Done():
			aborted = true
			// A deadline in the past makes pending reads/writes fail at once.
			_ = netconn.SetDeadline(time.Unix(1, 0))
		case <-handshakeDone:
		}
	}()
	sshconn, chans, reqs, err := ssh.NewClientConn(netconn, addr, cliConfig)
	close(handshakeDone)
	<-watcherExited // the watcher has finished, so reading aborted is race-free

	if aborted {
		// The conn deadline has been poisoned; even a completed handshake
		// would produce an unusable client, so tear everything down.
		if err == nil {
			sshconn.Close()
		}
		netconn.Close()
		return nil, errors.Wrapf(ErrNetwork, "ssh handshake: %v", ctx.Err())
	}
	if err != nil {
		netconn.Close()
		return nil, errors.Wrap(ErrProtocol, err.Error())
	}
	return ssh.NewClient(sshconn, chans, reqs), nil
}
// Client is an established SSH connection paired with the ClientConfig that
// was used to create it.
type Client struct {
	config ClientConfig // returned by GetConfig; Username/Host appear in run's debug log
	client *ssh.Client  // underlying connection; closed by Close
}
// NewClient connects using conf (see Connect) and returns a Client that
// keeps a copy of conf alongside the live connection.
func (conf ClientConfig) NewClient() (*Client, error) {
	sshClient, err := conf.Connect()
	if err != nil {
		return nil, err
	}
	c := new(Client)
	c.config = conf
	c.client = sshClient
	return c, nil
}
// NewClient is a convenience constructor that assembles a ClientConfig from
// the individual parameters and connects with it (see ClientConfig.NewClient).
// An empty password or privateKey disables that authentication method.
func NewClient(
	host string,
	port int,
	username string,
	password string,
	privateKey string,
) (*Client, error) {
	return ClientConfig{
		Username:   username,
		Password:   password,
		Host:       host,
		Port:       port,
		PrivateKey: privateKey,
	}.NewClient()
}
// GetConfig returns a copy of the configuration this client was created with.
func (s *Client) GetConfig() ClientConfig {
	cfg := s.config
	return cfg
}
// RawRun executes cmds in order and returns each command's stdout verbatim,
// one element per command, without requesting a pty or splitting lines.
func (s *Client) RawRun(cmds ...string) ([]string, error) {
	const parseOutput, withPty = false, false
	return s.run(parseOutput, cmds, nil, withPty)
}
// RunCmd executes the single command cmd and returns its stdout split into
// trimmed lines; it is equivalent to Run(cmd).
func (s *Client) RunCmd(cmd string) ([]string, error) {
	single := []string{cmd}
	return s.run(true, single, nil, false)
}
// Run executes cmds in order and returns the stdout of all of them, split
// into whitespace-trimmed lines and concatenated (see ParseOutput).
func (s *Client) Run(cmds ...string) ([]string, error) {
	const parseOutput, withPty = true, false
	return s.run(parseOutput, cmds, nil, withPty)
}
// RunWithInput behaves like Run but attaches input as the remote stdin of
// each command. The same reader is shared by every command, so a command that
// consumes it leaves nothing for the following ones.
func (s *Client) RunWithInput(input io.Reader, cmds ...string) ([]string, error) {
	var stdin io.Reader = input
	return s.run(true, cmds, stdin, false)
}
// RunWithTTY requests a pseudo-terminal before running each command and
// returns each command's raw stdout, one element per command (like RawRun).
func (s *Client) RunWithTTY(cmds ...string) ([]string, error) {
	const parseOutput, withPty = false, true
	return s.run(parseOutput, cmds, nil, withPty)
}
// run executes cmds sequentially, each on its own SSH session, and collects
// their stdout.
//
// When parseOutput is true, each command's stdout is split into trimmed lines
// (ParseOutput) and all lines are concatenated into the result. Otherwise the
// result holds one raw stdout string per command. input, when non-nil, is used
// as stdin for every command. withPty requests an 80x24 "xterm" pty per
// session.
//
// The first failing command aborts the run and nil is returned together with
// an error that carries the command, its stderr/stdout and the session error.
func (s *Client) run(parseOutput bool, cmds []string, input io.Reader, withPty bool) ([]string, error) {
	ret := []string{}
	for _, cmd := range cmds {
		out, err := s.runOne(cmd, input, withPty)
		if err != nil {
			return nil, err
		}
		if parseOutput {
			ret = append(ret, ParseOutput(out)...)
		} else {
			ret = append(ret, string(out))
		}
	}
	return ret, nil
}

// runOne runs a single command on a fresh session and returns its stdout.
// The session is closed before runOne returns. Previously the close was
// deferred inside run's loop, which kept every session open until all
// commands had finished.
func (s *Client) runOne(cmd string, input io.Reader, withPty bool) ([]byte, error) {
	session, err := s.client.NewSession()
	if err != nil {
		return nil, err
	}
	defer session.Close()
	if withPty {
		modes := ssh.TerminalModes{
			ssh.ECHO:          1,     // enable echoing
			ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud
			ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud
		}
		if err := session.RequestPty("xterm", 24, 80, modes); err != nil {
			return nil, errors.Wrap(err, "Setup TTY")
		}
	}
	log.Debugf("Run command(%s@%s): %s", s.config.Username, s.config.Host, cmd)
	var stdOut, stdErr bytes.Buffer
	session.Stdout = &stdOut
	session.Stderr = &stdErr
	session.Stdin = input
	if err := session.Run(cmd); err != nil {
		// Surface whatever the command printed so the caller can see why it failed.
		errMsg := stdErr.String()
		if stdOut.Len() != 0 {
			errMsg = fmt.Sprintf("%s %s", errMsg, stdOut.String())
		}
		return nil, errors.Wrapf(errors.Error(errMsg), "%q error: %v, cmd error", cmd, err)
	}
	return stdOut.Bytes(), nil
}
// ParseOutput splits output on '\n' and trims leading/trailing whitespace
// (including any '\r') from every piece. Empty pieces are kept: a trailing
// newline produces a final "" element, and empty input yields [""].
func ParseOutput(output []byte) []string {
	lines := strings.Split(string(output), "\n")
	for i := range lines {
		lines[i] = strings.TrimSpace(lines[i])
	}
	return lines
}
// Close shuts down the underlying SSH connection.
func (s *Client) Close() {
	// Close has no error result, so a failure here is deliberately dropped.
	_ = s.client.Close()
}
// updateTermSize forwards local terminal size changes (taken from stdin's
// terminal) to the remote session until a value is received on quit.
//
// Resize notifications arrive on a channel registered via setsignal. A nil
// signal is treated as "no resize notifications available": the function then
// just waits for quit. NOTE(review): presumably the platform-specific
// setsignal sends nil where SIGWINCH is unsupported; confirm.
func updateTermSize(session *ssh.Session, quit <-chan int) {
	sigwinchCh := make(chan os.Signal, 1)
	setsignal(sigwinchCh)
	fd := int(os.Stdin.Fd())
	width, height, err := terminal.GetSize(fd)
	if err != nil {
		log.Errorf("get terminal size: %v", err)
	}
	for {
		select {
		case <-quit:
			return
		case sig := <-sigwinchCh:
			if sig == nil {
				<-quit
				return
			}
			termWidth, termHeight, err := terminal.GetSize(fd)
			if err != nil {
				log.Errorf("get terminal size: %v", err)
				// The dimensions returned on failure are not meaningful;
				// forwarding them could shrink the remote pty to 0x0.
				continue
			}
			if termHeight == height && termWidth == width {
				continue
			}
			if err := session.WindowChange(termHeight, termWidth); err != nil {
				log.Errorf("send window-change request: %v", err)
				continue
			}
			width, height = termWidth, termHeight
		}
	}
}
// RunTerminal starts an interactive remote shell with the local stdin, stdout
// and stderr attached. The local terminal is switched to raw mode for the
// duration and restored afterwards. The pty uses $TERM, or "xterm-256color"
// when it is unset, and local size changes are forwarded by updateTermSize.
//
// It blocks until the shell exits. A remote exit status of 130 (commonly what
// a shell reports after Ctrl-C) is treated as success. The Client's underlying
// connection is always closed before RunTerminal returns.
func (s *Client) RunTerminal() error {
	defer s.Close()

	session, err := s.client.NewSession()
	if err != nil {
		return errors.Wrap(err, "open new session")
	}
	defer session.Close()

	stdinFd := int(os.Stdin.Fd())
	oldState, err := terminal.MakeRaw(stdinFd)
	if err != nil {
		return errors.Wrap(err, "make raw terminal")
	}
	defer terminal.Restore(stdinFd, oldState)

	cols, rows, err := terminal.GetSize(stdinFd)
	if err != nil {
		return errors.Wrap(err, "get terminal size")
	}

	termType := os.Getenv("TERM")
	if termType == "" {
		termType = "xterm-256color"
	}
	ptyModes := ssh.TerminalModes{
		ssh.ECHO:          1,
		ssh.TTY_OP_ISPEED: 14400,
		ssh.TTY_OP_OSPEED: 14400,
	}
	if err := session.RequestPty(termType, rows, cols, ptyModes); err != nil {
		return errors.Wrap(err, "session xterm")
	}

	session.Stdin = os.Stdin
	session.Stdout = os.Stdout
	session.Stderr = os.Stderr
	if err := session.Shell(); err != nil {
		return errors.Wrap(err, "session shell")
	}

	quit := make(chan int)
	go updateTermSize(session, quit)
	waitErr := session.Wait()
	// Stop the resize watcher (synchronously) before the deferred cleanups run.
	quit <- 1
	if waitErr == nil {
		return nil
	}
	if exitErr, ok := waitErr.(*ssh.ExitError); ok && exitErr.ExitStatus() == 130 {
		return nil
	}
	return errors.Wrap(waitErr, "ssh wait")
}
// IsExitMissingError reports whether err's message contains the text of an
// ssh.ExitMissingError, the error x/crypto/ssh returns when a session ends
// without the server sending an exit status.
//
// Matching on the message rather than the type also recognises the condition
// after run has folded the original error into a string-based error. A nil
// err reports false instead of panicking.
func IsExitMissingError(err error) bool {
	if err == nil {
		return false
	}
	return strings.Contains(err.Error(), new(ssh.ExitMissingError).Error())
}