From 13dae00f8a44e001eb47460222c1af5cf8f6b966 Mon Sep 17 00:00:00 2001 From: Kimmo Lehto Date: Fri, 5 Apr 2024 09:39:23 +0300 Subject: [PATCH 1/7] wip Signed-off-by: Kimmo Lehto --- protocol/connection.go | 5 +++++ protocol/ssh/config.go | 23 ++++++++++++++++------- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/protocol/connection.go b/protocol/connection.go index 28980dd1..7672f07b 100644 --- a/protocol/connection.go +++ b/protocol/connection.go @@ -60,3 +60,8 @@ type Connection interface { ProcessStarter WindowsChecker } + +// DefaultsSetter has a SetDefaults method +type DefaultsSetter interface { + SetDefaults() error +} diff --git a/protocol/ssh/config.go b/protocol/ssh/config.go index d7ce23cc..13025282 100644 --- a/protocol/ssh/config.go +++ b/protocol/ssh/config.go @@ -8,6 +8,7 @@ import ( "github.com/k0sproject/rig/v2/homedir" "github.com/k0sproject/rig/v2/log" "github.com/k0sproject/rig/v2/protocol" + "github.com/k0sproject/rig/v2/sshconfig" ssh "golang.org/x/crypto/ssh" ) @@ -22,6 +23,7 @@ type Config struct { Port int `yaml:"port" default:"22" validate:"gt=0,lte=65535"` KeyPath *string `yaml:"keyPath" validate:"omitempty"` Bastion *Config `yaml:"bastion,omitempty"` + ConfigPath string `yaml:"configPath,omitempty"` PasswordCallback PasswordCallback `yaml:"-"` // AuthMethods can be used to pass in a list of crypto/ssh.AuthMethod objects @@ -30,6 +32,9 @@ type Config struct { // For convenience, you can use ParseSSHPrivateKey() to parse a private key: // authMethods, err := ssh.ParseSSHPrivateKey(key, rig.DefaultPassphraseCallback) AuthMethods []ssh.AuthMethod `yaml:"-"` + + sshconfig.Config `yaml:",inline"` + parser *sshconfig.Parser } // Connection returns a new Connection object based on the configuration. @@ -45,20 +50,24 @@ func (c *Config) String() string { // SetDefaults sets the default values for the configuration. func (c *Config) SetDefaults() { - if c.Port == 0 { - c.Port = 22 - } - if c.User == "" { - c.User = "root" - } if c.KeyPath != nil { if path, err := homedir.Expand(*c.KeyPath); err == nil { - c.KeyPath = &path + c.IdentityFile = []string{path} } } + c.Host = c.Address if c.Bastion != nil { c.Bastion.SetDefaults() } + /* + + TODO setdefaults needs to be able to return an error + if c.ConfigPath != "" { + cfgPath, err := homedir.Expand(c.ConfigPath) + + c.parser = sshconfig.NewParser(c.ConfigPath) + } + */ } // Validate returns an error if the configuration is invalid. From e2ecd4dc20639b1b5298cb1a560c02a9f575398a Mon Sep 17 00:00:00 2001 From: Kimmo Lehto Date: Sat, 6 Apr 2024 00:50:52 +0300 Subject: [PATCH 2/7] Support embedded structs in ssh config parser Signed-off-by: Kimmo Lehto --- sshconfig/parser.go | 1 + sshconfig/set.go | 65 +++++++++++++++++++++++++++++++++++---------- 2 files changed, 52 insertions(+), 14 deletions(-) diff --git a/sshconfig/parser.go b/sshconfig/parser.go index caf55f8b..60e3918a 100644 --- a/sshconfig/parser.go +++ b/sshconfig/parser.go @@ -314,6 +314,7 @@ func (p *Parser) apply(setter *Setter) error { p.iter.Skip() } default: + log.Trace(context.Background(), "applying", "key", key, "values", values, "path", path, "row", row) if err := setter.Set(key, values...); err != nil { return fmt.Errorf("set %q: %w", key, err) } diff --git a/sshconfig/set.go b/sshconfig/set.go index 0fcdbc52..24fc0e7d 100644 --- a/sshconfig/set.go +++ b/sshconfig/set.go @@ -153,31 +153,67 @@ func (s *Setter) discoverFields() { sfcMu.Lock() defer sfcMu.Unlock() - var sfields map[string]reflect.StructField if sfCache == nil { sfCache = make(map[reflect.Type]map[string]reflect.StructField) } - sf, cached := sfCache[s.elem.Type()] - if cached { - sfields = sf + t := s.elem.Type() + if sfields, cached := sfCache[t]; cached { s.elemFields = make(map[string]reflect.Value) for k, v := range sfields { - s.elemFields[k] = s.elem.FieldByIndex(v.Index) + fieldVal := s.elem + for _, idx := range v.Index { + if fieldVal.Kind() == reflect.Ptr { + fieldVal = fieldVal.Elem() + } + fieldVal = fieldVal.Field(idx) + } + s.elemFields[k] = fieldVal } - return + } else { + sfields = make(map[string]reflect.StructField) + sfCache[t] = sfields + s.elemFields = make(map[string]reflect.Value) + + collectFields(t, sfields, s.elem, nil, s.elemFields) } +} - sfields = make(map[string]reflect.StructField) - sfCache[s.elem.Type()] = sfields +func collectFields(t reflect.Type, sfields map[string]reflect.StructField, v reflect.Value, indexPrefix []int, elemFields map[string]reflect.Value) { + if t.Kind() == reflect.Ptr { + t = t.Elem() + v = v.Elem() + } - s.elemFields = make(map[string]reflect.Value) + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if !field.IsExported() { + continue + } + index := append(indexPrefix, i) + + fieldType := field.Type + if fieldType.Kind() == reflect.Ptr { + fieldType = fieldType.Elem() + } - for i := 0; i < s.elem.NumField(); i++ { - field := s.elem.Field(i) - structField := s.elem.Type().Field(i) - s.elemFields[structField.Name] = field - sfields[structField.Name] = structField + fieldVal := v.Field(i) + if field.Anonymous && fieldType.Kind() == reflect.Struct { + if fieldVal.Kind() == reflect.Ptr && !fieldVal.IsNil() { + fieldVal = fieldVal.Elem() + } + collectFields(fieldType, sfields, fieldVal, index, elemFields) + } else { + sfields[field.Name] = reflect.StructField{ + Name: field.Name, + Type: fieldType, + Index: index, + Anonymous: field.Anonymous, + } + if fieldVal.IsValid() { + elemFields[field.Name] = fieldVal + } + } } } @@ -958,6 +994,7 @@ func (s *Setter) Set(key string, values ...string) error { if errors.Is(err, errFieldNotFound) || errors.Is(err, errInvalidField) { if !s.ErrorOnUnknownFields { + log.Trace(context.Background(), "ignoring unknown key because not in strict mode", "key", key) return nil } if !s.isInIgnoreUnknown(key) { From fb7cd48d7662a1a18b40c13f2ae513f0cb072765 Mon Sep 17 00:00:00 2001 From: Kimmo Lehto Date: Sat, 6 Apr 2024 00:51:45 +0300 Subject: [PATCH 3/7] Abort ssh connect retry on "no methods remaining" Signed-off-by: Kimmo Lehto --- client.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index 5e5ca2b1..049615a9 100644 --- a/client.go +++ b/client.go @@ -22,6 +22,7 @@ import ( "errors" "fmt" "io" + "strings" "sync" "time" @@ -304,7 +305,9 @@ func (c *Client) Connect(ctx context.Context) error { err := retry.DoWithContext(ctx, func(ctx context.Context) error { return c.connect(ctx) }, retry.If( - func(err error) bool { return !errors.Is(err, protocol.ErrAbort) }, + func(err error) bool { + return !errors.Is(err, protocol.ErrAbort) && !strings.Contains(err.Error(), "no supported methods") + }, )) if err != nil { return fmt.Errorf("client connect: %w", err) From 228282677dec1b753b88f04121864193e90054f4 Mon Sep 17 00:00:00 2001 From: Kimmo Lehto Date: Sat, 6 Apr 2024 00:52:29 +0300 Subject: [PATCH 4/7] Improve SSH setup Signed-off-by: Kimmo Lehto --- protocol/connection.go | 4 -- protocol/openssh/config.go | 33 +++++------ protocol/openssh/connection.go | 12 ++-- protocol/ssh/config.go | 81 ++++++++++++++++++-------- protocol/ssh/connection.go | 102 +++++++++++++-------------------- protocol/ssh/options.go | 14 ++++- test/rig_test.go | 34 ++++------- 7 files changed, 138 insertions(+), 142 deletions(-) diff --git a/protocol/connection.go b/protocol/connection.go index 7672f07b..07b2cff3 100644 --- a/protocol/connection.go +++ b/protocol/connection.go @@ -61,7 +61,3 @@ type Connection interface { WindowsChecker } -// DefaultsSetter has a SetDefaults method -type DefaultsSetter interface { - SetDefaults() error -} diff --git a/protocol/openssh/config.go b/protocol/openssh/config.go index ed53d175..ba8e9d19 100644 --- a/protocol/openssh/config.go +++ b/protocol/openssh/config.go @@ -11,9 +11,8 @@ import ( // Config describes the configuration options for an OpenSSH connection. type Config struct { - Address string `yaml:"address" validate:"required"` + protocol.Endpoint `yaml:",inline"` User *string `yaml:"user"` - Port *int `yaml:"port"` KeyPath *string `yaml:"keyPath,omitempty"` ConfigPath *string `yaml:"configPath,omitempty"` Options OptionArguments `yaml:"options,omitempty"` @@ -27,34 +26,28 @@ func (c *Config) Connection() (protocol.Connection, error) { // String returns a string representation of the configuration. func (c *Config) String() string { - if c.Port == nil { + if c.Port == 0 { return "openssh.Config{" + c.Address + "}" } - return "openssh.Config{" + net.JoinHostPort(c.Address, strconv.Itoa(*c.Port)) + "}" + return "openssh.Config{" + net.JoinHostPort(c.Address, strconv.Itoa(c.Port)) + "}" } // SetDefaults sets the default values for the configuration. -func (c *Config) SetDefaults() { +func (c *Config) SetDefaults() error { if c.KeyPath != nil { - if path, err := homedir.Expand(*c.KeyPath); err == nil { - c.KeyPath = &path + path, err := homedir.Expand(*c.KeyPath) + if err != nil { + return fmt.Errorf("keypath: %w", err) } + c.KeyPath = &path } + if c.ConfigPath != nil { - if path, err := homedir.Expand(*c.ConfigPath); err == nil { - c.ConfigPath = &path + path, err := homedir.Expand(*c.ConfigPath) + if err != nil { + return fmt.Errorf("configpath: %w", err) } - } -} - -// Validate checks the configuration for any invalid values. -func (c *Config) Validate() error { - if c.Address == "" { - return fmt.Errorf("%w: address is required", protocol.ErrValidationFailed) - } - - if c.Port != nil && (*c.Port <= 0 || *c.Port > 65535) { - return fmt.Errorf("%w: port must be between 1 and 65535", protocol.ErrValidationFailed) + c.ConfigPath = &path } return nil diff --git a/protocol/openssh/connection.go b/protocol/openssh/connection.go index b3a96870..1918362a 100644 --- a/protocol/openssh/connection.go +++ b/protocol/openssh/connection.go @@ -40,7 +40,9 @@ type Connection struct { // NewConnection creates a new OpenSSH connection. Error is currently always nil. func NewConnection(cfg Config) (*Connection, error) { - cfg.SetDefaults() + if err := cfg.SetDefaults(); err != nil { + return nil, fmt.Errorf("set openssh config defaults: %w", err) + } return &Connection{Config: cfg}, nil } @@ -123,8 +125,8 @@ func (c *Connection) args() []string { if c.KeyPath != nil && *c.KeyPath != "" { args = append(args, "-i", *c.KeyPath) } - if c.Port != nil { - args = append(args, "-p", strconv.Itoa(*c.Port)) + if c.Port != 0 { + args = append(args, "-p", strconv.Itoa(c.Port)) } if c.ConfigPath != nil && *c.ConfigPath != "" { args = append(args, "-F", *c.ConfigPath) @@ -250,10 +252,10 @@ func (c *Connection) String() string { return c.name } - if c.Port == nil { + if c.Port == 0 { c.name = c.userhost() } else { - c.name = fmt.Sprintf("%s:%d", c.userhost(), *c.Port) + c.name = fmt.Sprintf("%s:%d", c.userhost(), c.Port) } return c.name diff --git a/protocol/ssh/config.go b/protocol/ssh/config.go index 13025282..b438c757 100644 --- a/protocol/ssh/config.go +++ b/protocol/ssh/config.go @@ -18,12 +18,11 @@ type PasswordCallback func() (secret string, err error) // Config describes an SSH connection's configuration. type Config struct { log.LoggerInjectable `yaml:"-"` - Address string `yaml:"address" validate:"required,hostname_rfc1123|ip"` + protocol.Endpoint `yaml:",inline"` User string `yaml:"user" validate:"required" default:"root"` - Port int `yaml:"port" default:"22" validate:"gt=0,lte=65535"` KeyPath *string `yaml:"keyPath" validate:"omitempty"` Bastion *Config `yaml:"bastion,omitempty"` - ConfigPath string `yaml:"configPath,omitempty"` + ConfigPath string `yaml:"configPath,omitempty"` PasswordCallback PasswordCallback `yaml:"-"` // AuthMethods can be used to pass in a list of crypto/ssh.AuthMethod objects @@ -34,50 +33,86 @@ type Config struct { AuthMethods []ssh.AuthMethod `yaml:"-"` sshconfig.Config `yaml:",inline"` - parser *sshconfig.Parser + + options *Options } // Connection returns a new Connection object based on the configuration. func (c *Config) Connection() (protocol.Connection, error) { - conn, err := NewConnection(*c, WithLogger(c.Log())) + conn, err := NewConnection(*c, c.options.Funcs()...) + if !log.HasLogger(conn) && log.HasLogger(c) { + log.InjectLogger(c.Log(), c) + } return conn, err } // String returns a string representation of the configuration. func (c *Config) String() string { - return "ssh.Config{" + net.JoinHostPort(c.Address, strconv.Itoa(c.Port)) + "}" + return "ssh.Config{" + net.JoinHostPort(c.Address, strconv.Itoa(c.Endpoint.Port)) + "}" } // SetDefaults sets the default values for the configuration. -func (c *Config) SetDefaults() { +func (c *Config) SetDefaults(opts ...Option) error { + options := NewOptions(opts...) + if c.KeyPath != nil { - if path, err := homedir.Expand(*c.KeyPath); err == nil { - c.IdentityFile = []string{path} + path, err := homedir.Expand(*c.KeyPath) + if err != nil { + return fmt.Errorf("keypath: %w", err) } + c.KeyPath = &path + c.IdentityFile = []string{path} } + c.Host = c.Address - if c.Bastion != nil { - c.Bastion.SetDefaults() + if c.Endpoint.Port != 0 { + c.Config.Port = c.Endpoint.Port } - /* - TODO setdefaults needs to be able to return an error - if c.ConfigPath != "" { - cfgPath, err := homedir.Expand(c.ConfigPath) + if c.User != "" { + c.Config.User = c.User + } - c.parser = sshconfig.NewParser(c.ConfigPath) + var parser ConfigParser + if options.ConfigParser != nil { + parser = options.ConfigParser + } else { + p, err := ParserCache().Get(c.ConfigPath) + if err != nil { + return fmt.Errorf("get ssh config parser: %w", err) + } + parser = p } - */ + + if err := parser.Apply(c, c.Address); err != nil { + return fmt.Errorf("apply values from ssh config: %w", err) + } + + c.Endpoint.Port = c.Config.Port + + if c.Config.User != "" { + c.User = c.Config.User + } + + if c.Config.Hostname != "" { + c.Address = c.Config.Hostname + } else { + c.Address = c.Host + } + + if c.Bastion != nil { + if err := c.Bastion.SetDefaults(); err != nil { + return fmt.Errorf("bastion: %w", err) + } + } + + return nil } // Validate returns an error if the configuration is invalid. func (c *Config) Validate() error { - if c.Address == "" { - return fmt.Errorf("%w: address is required", protocol.ErrValidationFailed) - } - - if c.Port <= 0 || c.Port > 65535 { - return fmt.Errorf("%w: port must be between 1 and 65535", protocol.ErrValidationFailed) + if err := c.Endpoint.Validate(); err != nil { + return fmt.Errorf("endpoint: %w", err) } if c.KeyPath != nil { diff --git a/protocol/ssh/connection.go b/protocol/ssh/connection.go index c348f1b2..f043c2cb 100644 --- a/protocol/ssh/connection.go +++ b/protocol/ssh/connection.go @@ -20,7 +20,6 @@ import ( "github.com/k0sproject/rig/v2/protocol" "github.com/k0sproject/rig/v2/protocol/ssh/agent" "github.com/k0sproject/rig/v2/protocol/ssh/hostkey" - "github.com/k0sproject/rig/v2/sshconfig" ssh "golang.org/x/crypto/ssh" "golang.org/x/term" ) @@ -32,12 +31,9 @@ type Connection struct { log.LoggerInjectable `yaml:"-"` Config `yaml:",inline"` - sshConfig *sshconfig.Config - options *Options - alias string - name string + name string isWindows *bool once sync.Once @@ -53,30 +49,12 @@ type Connection struct { func NewConnection(cfg Config, opts ...Option) (*Connection, error) { options := NewOptions(opts...) options.InjectLoggerTo(cfg, log.KeyProtocol, "ssh-config") - cfg.SetDefaults() - - c := &Connection{Config: cfg, options: options} //nolint:varnamelen - options.InjectLoggerTo(c, log.KeyProtocol, "ssh") - c.sshConfig = &sshconfig.Config{ - User: c.Config.User, - Host: c.Config.Address, - } - - if c.Config.Port != 0 && c.Config.Port != 22 { - c.sshConfig.Port = c.Config.Port - } - - if c.Config.KeyPath != nil { - c.sshConfig.IdentityFile = []string{*c.Config.KeyPath} + if err := cfg.SetDefaults(opts...); err != nil { + return nil, fmt.Errorf("set ssh config defaults: %w", err) } - if ConfigParser != nil { - if err := ConfigParser.Apply(c.sshConfig, c.Config.Address); err != nil { - return nil, fmt.Errorf("failed to apply ssh config: %w", err) - } - } - - c.Config.Port = c.sshConfig.Port + c := &Connection{Config: cfg, options: options} + options.InjectLoggerTo(c, log.KeyProtocol, "ssh") return c, nil } @@ -85,22 +63,11 @@ var ( authMethodCache = sync.Map{} knownHostsMU sync.Mutex - globalOnce sync.Once // ErrChecksumMismatch is returned when the checksum of an uploaded file does not match expectation. ErrChecksumMismatch = errors.New("checksum mismatch") ) -// TODO make the parser initialization more elegant. -func init() { - globalOnce.Do(func() { - parser, err := sshconfig.NewParser(nil) - if err == nil { - ConfigParser = parser - } - }) -} - // Dial initiates a connection to the addr from the remote host. func (c *Connection) Dial(network, address string) (net.Conn, error) { conn, err := c.client.Dial(network, address) @@ -112,7 +79,7 @@ func (c *Connection) Dial(network, address string) (net.Conn, error) { func (c *Connection) keypathsFromConfig() []string { log.Trace(context.Background(), "trying to get a keyfile path from ssh config", log.KeyHost, c) - idf := slices.Compact(c.sshConfig.IdentityFile) + idf := slices.Compact(c.IdentityFile) if len(idf) > 0 { log.Trace(context.Background(), fmt.Sprintf("detected %d identity file paths from ssh config", len(idf)), log.KeyFile, idf) @@ -125,19 +92,13 @@ func (c *Connection) keypathsFromConfig() []string { // SetDefaults sets various default values. func (c *Connection) SetDefaults() { c.once.Do(func() { - c.Port = c.sshConfig.Port - - if c.sshConfig.Hostname != "" { - c.alias = c.Address - c.Address = c.sshConfig.Hostname - } - for _, p := range c.keypathsFromConfig() { expanded, err := homedir.ExpandFile(p) if err != nil { log.Trace(context.Background(), "expand and validate", log.KeyFile, p, log.KeyError, err) continue } + log.Trace(context.Background(), "using identity file", log.KeyFile, expanded) c.Log().Debug("using identity file", log.KeyFile, expanded) c.keyPaths = append(c.keyPaths, expanded) } @@ -163,13 +124,10 @@ func (c *Connection) IsConnected() bool { return err == nil } -// ConfigParser is an instance of rig/v2/sshconfig.Parser - it is exported here for weird design decisions made in rig v0.x and will be removed in rig v2 final. -var ConfigParser *sshconfig.Parser - // String returns the connection's printable name. func (c *Connection) String() string { if c.name == "" { - c.name = net.JoinHostPort(c.Address, strconv.Itoa(c.Port)) + c.name = net.JoinHostPort(c.Address, strconv.Itoa(c.Endpoint.Port)) } return c.name @@ -180,7 +138,7 @@ func (c *Connection) Disconnect() { if c.client == nil { return } - if c.options.KeepAliveInterval != nil { + if c.Config.ServerAliveInterval != 0 { close(c.done) } c.client.Close() @@ -230,7 +188,7 @@ func knownhostsCallback(path string, permissive, hash bool) (ssh.HostKeyCallback } func isPermissive(c *Connection) bool { - if c.sshConfig.StrictHostKeyChecking.IsFalse() { + if c.StrictHostKeyChecking.IsFalse() { log.Trace(context.Background(), "config StrictHostkeyChecking is set to 'no'", log.KeyHost, c) return true } @@ -239,7 +197,7 @@ func isPermissive(c *Connection) bool { } func shouldHash(c *Connection) bool { - if c.sshConfig.HashKnownHosts.IsTrue() { + if c.HashKnownHosts.IsTrue() { log.Trace(context.Background(), "config HashKnownHosts is set", log.KeyHost, c) return true } @@ -263,7 +221,7 @@ func (c *Connection) hostkeyCallback() (ssh.HostKeyCallback, error) { var khPath string - for _, f := range c.sshConfig.UserKnownHostsFile { + for _, f := range c.UserKnownHostsFile { log.Trace(context.Background(), "trying known_hosts file from ssh config", log.KeyHost, c, log.KeyFile, f) exp, err := homedir.Expand(f) if err == nil { @@ -277,10 +235,20 @@ func (c *Connection) hostkeyCallback() (ssh.HostKeyCallback, error) { return knownhostsCallback(khPath, permissive, hash) } - return nil, fmt.Errorf("%w: no known_hosts file found", protocol.ErrAbort) + if len(c.UserKnownHostsFile) > 0 { + khPath = c.UserKnownHostsFile[0] + log.Trace(context.Background(), "using new known_hosts file", log.KeyHost, c, log.KeyFile, khPath) + return knownhostsCallback(khPath, permissive, hash) + } + + khPath = os.ExpandEnv("$HOME/.ssh/known_hosts") + log.Trace(context.Background(), "using default known_hosts file", log.KeyHost, c, log.KeyFile, khPath) + return knownhostsCallback(khPath, permissive, hash) } func (c *Connection) clientConfig() (*ssh.ClientConfig, error) { //nolint:cyclop + log.Trace(context.Background(), "creating client config", log.HostAttr(c), "user", c.User) + config := &ssh.ClientConfig{ User: c.User, } @@ -329,12 +297,15 @@ func (c *Connection) clientConfig() (*ssh.ClientConfig, error) { //nolint:cyclop } continue } + log.Trace(context.Background(), "trying to get a signer for identity", log.KeyFile, keyPath) privateKeyAuth, err := c.pkeySigner(signers, keyPath) if err != nil { + log.Trace(context.Background(), "failed to get a signer for identity", log.KeyFile, keyPath, log.ErrorAttr(err)) c.Log().Debug("failed to obtain a signer for identity", log.KeyFile, keyPath, log.ErrorAttr(err)) // store the error so this key won't be loaded again authMethodCache.Store(keyPath, err) } else { + log.Trace(context.Background(), "using public key signer", log.KeyFile, keyPath) authMethodCache.Store(keyPath, privateKeyAuth) config.Auth = append(config.Auth, privateKeyAuth) } @@ -363,10 +334,12 @@ func (c *Connection) connectViaBastion(dst string, config *ssh.ClientConfig) err } return err } + log.Trace(context.Background(), "connecting bastion", log.HostAttr(c), "destination", dst) bconn, err := bastionSSH.Dial("tcp", dst) if err != nil { return fmt.Errorf("bastion dial: %w", err) } + log.Trace(context.Background(), "creating client connection through bastion", log.HostAttr(c), "destination", dst) client, chans, reqs, err := ssh.NewClientConn(bconn, dst, config) if err != nil { if errors.Is(err, hostkey.ErrHostKeyMismatch) { @@ -376,19 +349,18 @@ func (c *Connection) connectViaBastion(dst string, config *ssh.ClientConfig) err } c.client = ssh.NewClient(client, chans, reqs) - c.startKeepalive() + if c.Config.ServerAliveInterval != 0 { + c.startKeepalive() + } return nil } func (c *Connection) startKeepalive() { - if c.options.KeepAliveInterval == nil { - return - } - + log.Trace(context.Background(), "starting keepalive", log.HostAttr(c), "interval", c.Config.ServerAliveInterval) c.done = make(chan struct{}) go func() { - ticker := time.NewTicker(*c.options.KeepAliveInterval) + ticker := time.NewTicker(c.Config.ServerAliveInterval) defer ticker.Stop() for { select { @@ -413,14 +385,16 @@ func (c *Connection) Connect() error { return fmt.Errorf("%w: create config: %w", protocol.ErrAbort, err) } - dst := net.JoinHostPort(c.Address, strconv.Itoa(c.Port)) + dst := net.JoinHostPort(c.Address, strconv.Itoa(c.Endpoint.Port)) if c.Bastion != nil { return c.connectViaBastion(dst, config) } + log.Trace(context.Background(), "connecting directly", log.HostAttr(c), "destination", dst) clientDirect, err := ssh.Dial("tcp", dst, config) if err != nil { + log.Trace(context.Background(), "dial failed", log.HostAttr(c), "destination", dst, log.ErrorAttr(err)) if errors.Is(err, hostkey.ErrHostKeyMismatch) { return fmt.Errorf("%w: %w", protocol.ErrAbort, err) } @@ -428,7 +402,9 @@ func (c *Connection) Connect() error { } c.client = clientDirect - c.startKeepalive() + if c.Config.ServerAliveInterval != 0 { + c.startKeepalive() + } return nil } diff --git a/protocol/ssh/options.go b/protocol/ssh/options.go index b951a838..d413e2aa 100644 --- a/protocol/ssh/options.go +++ b/protocol/ssh/options.go @@ -10,6 +10,13 @@ import ( type Options struct { log.LoggerInjectable KeepAliveInterval *time.Duration + ConfigParser ConfigParser + + funcs []Option +} + +func (o *Options) Funcs() []Option { + return o.funcs } // Option is a function that sets some option on the Options struct. @@ -19,6 +26,7 @@ type Option func(*Options) func NewOptions(opts ...Option) *Options { o := &Options{} for _, opt := range opts { + o.funcs = append(o.funcs, opt) opt(o) } return o @@ -31,9 +39,9 @@ func WithLogger(l log.Logger) Option { } } -// WithKeepAlive sets the keep-alive interval option. -func WithKeepAlive(d time.Duration) Option { +// WithConfigParser sets a custom ssh configuration parser. +func WithConfigParser(p ConfigParser) Option { return func(o *Options) { - o.KeepAliveInterval = &d + o.ConfigParser = p } } diff --git a/test/rig_test.go b/test/rig_test.go index 90cd25bf..6631a6be 100644 --- a/test/rig_test.go +++ b/test/rig_test.go @@ -26,6 +26,7 @@ import ( "github.com/k0sproject/rig/v2/remotefs" "github.com/k0sproject/rig/v2/rigtest" "github.com/k0sproject/rig/v2/sshconfig" + "github.com/k0sproject/rig/v2/sshconfig/options" "github.com/k0sproject/rig/v2/stattime" "github.com/stretchr/testify/require" @@ -90,19 +91,6 @@ func TestMain(m *testing.M) { } } - if configPath != "" { - f, err := os.Open(configPath) - if err != nil { - panic(err) - } - defer f.Close() - parser, err := sshconfig.NewParser(f) - if err != nil { - panic(err) - } - ssh.ConfigParser = parser - } - // Run tests os.Exit(m.Run()) } @@ -150,14 +138,16 @@ type Host struct { *rig.Client } -func GetHost(t *testing.T, options ...rig.ClientOption) *Host { +func GetHost(t *testing.T, clientOptions ...rig.ClientOption) *Host { var client protocol.Connection + endpoint := protocol.Endpoint{Address: targetHost, Port: targetPort} switch proto { case "ssh": cfg := ssh.Config{ - Address: targetHost, - Port: targetPort, - User: username, + Endpoint: endpoint, + User: username, + ConfigPath: configPath, + Config: sshconfig.Config{StrictHostKeyChecking: options.StrictHostKeyCheckingOptionNo}, } if privateKey != "" { @@ -176,8 +166,7 @@ func GetHost(t *testing.T, options ...rig.ClientOption) *Host { client = sshclient case "winrm": cfg := winrm.Config{ - Address: targetHost, - Port: targetPort, + Endpoint: endpoint, User: username, UseHTTPS: useHTTPS, Insecure: true, @@ -190,12 +179,9 @@ func GetHost(t *testing.T, options ...rig.ClientOption) *Host { client, _ = localhost.NewConnection() case "openssh": cfg := openssh.Config{ - Address: targetHost, + Endpoint: endpoint, DisableMultiplexing: !enableMultiplex, } - if targetPort != 22 { - cfg.Port = &targetPort - } if keyPath != "" { cfg.KeyPath = &keyPath @@ -213,7 +199,7 @@ func GetHost(t *testing.T, options ...rig.ClientOption) *Host { panic("unknown protocol") } opts := []rig.ClientOption{rig.WithConnection(client), rig.WithLogger(slog.New(NewTestLogHandler(t)))} - opts = append(opts, options...) + opts = append(opts, clientOptions...) c, err := rig.NewClient(opts...) require.NoError(t, err) return &Host{Client: c} From fce53ff6a0a70ee685fbb1828f4ae953a9cd0717 Mon Sep 17 00:00:00 2001 From: Kimmo Lehto Date: Sat, 6 Apr 2024 00:52:53 +0300 Subject: [PATCH 5/7] Make config SetDefaults return an error on failure Signed-off-by: Kimmo Lehto --- protocol/winrm/config.go | 44 +++++++++++++++++++++++------------- protocol/winrm/connection.go | 4 +++- 2 files changed, 31 insertions(+), 17 deletions(-) diff --git a/protocol/winrm/config.go b/protocol/winrm/config.go index 8a12e731..c59da22b 100644 --- a/protocol/winrm/config.go +++ b/protocol/winrm/config.go @@ -14,13 +14,12 @@ import ( // Config describes the configuration options for a WinRM connection. type Config struct { log.LoggerInjectable `yaml:"-"` - Address string `yaml:"address" validate:"required,hostname_rfc1123|ip"` - User string `yaml:"user" validate:"omitempty,gt=2" default:"Administrator"` - Port int `yaml:"port" default:"5985" validate:"gt=0,lte=65535"` + protocol.Endpoint `yaml:",inline"` + User string `yaml:"user" validate:"omitempty,gt=2"` Password string `yaml:"password,omitempty"` - UseHTTPS bool `yaml:"useHTTPS" default:"false"` - Insecure bool `yaml:"insecure" default:"false"` - UseNTLM bool `yaml:"useNTLM" default:"false"` + UseHTTPS bool `yaml:"useHTTPS"` + Insecure bool `yaml:"insecure"` + UseNTLM bool `yaml:"useNTLM"` CACertPath string `yaml:"caCertPath,omitempty" validate:"omitempty,file"` CertPath string `yaml:"certPath,omitempty" validate:"omitempty,file"` KeyPath string `yaml:"keyPath,omitempty" validate:"omitempty,file"` @@ -29,16 +28,31 @@ type Config struct { } // SetDefaults sets various default values. -func (c *Config) SetDefaults() { - if p, err := homedir.Expand(c.CACertPath); err == nil { +func (c *Config) SetDefaults() error { + if c.User == "" { + c.User = "Administrator" + } + if c.CACertPath != "" { + p, err := homedir.Expand(c.CACertPath) + if err != nil { + return fmt.Errorf("cacertpath: %w", err) + } c.CACertPath = p } - if p, err := homedir.Expand(c.CertPath); err == nil { + if c.CertPath != "" { + p, err := homedir.Expand(c.CertPath) + if err != nil { + return fmt.Errorf("certpath: %w", err) + } c.CertPath = p } - if p, err := homedir.Expand(c.KeyPath); err == nil { + if c.KeyPath != "" { + p, err := homedir.Expand(c.KeyPath) + if err != nil { + return fmt.Errorf("keypath: %w", err) + } c.KeyPath = p } @@ -53,16 +67,14 @@ func (c *Config) SetDefaults() { case 5986: c.UseHTTPS = true } + + return nil } // Validate checks the configuration for any invalid values. func (c *Config) Validate() error { - if c.Address == "" { - return fmt.Errorf("%w: address is required", protocol.ErrValidationFailed) - } - - if c.Port <= 0 || c.Port > 65535 { - return fmt.Errorf("%w: port must be between 1 and 65535", protocol.ErrValidationFailed) + if err := c.Endpoint.Validate(); err != nil { + return fmt.Errorf("endpoint: %w", err) } if c.Bastion != nil { diff --git a/protocol/winrm/connection.go b/protocol/winrm/connection.go index 2958a75e..d3297bff 100644 --- a/protocol/winrm/connection.go +++ b/protocol/winrm/connection.go @@ -46,7 +46,9 @@ type dialFunc func(network, addr string) (net.Conn, error) func NewConnection(cfg Config, opts ...Option) (*Connection, error) { options := NewOptions(opts...) options.InjectLoggerTo(cfg, log.KeyProtocol, "winrm-config") - cfg.SetDefaults() + if err := cfg.SetDefaults(); err != nil { + return nil, fmt.Errorf("set winrm config defaults: %w", err) + } c := &Connection{Config: cfg} options.InjectLoggerTo(c, log.KeyProtocol, "winrm") From 87b965cda381f805b411515b54e42ff52884d19f Mon Sep 17 00:00:00 2001 From: Kimmo Lehto Date: Sat, 6 Apr 2024 00:53:06 +0300 Subject: [PATCH 6/7] Add protocol endpoints Signed-off-by: Kimmo Lehto --- protocol/endpoint.go | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 protocol/endpoint.go diff --git a/protocol/endpoint.go b/protocol/endpoint.go new file mode 100644 index 00000000..11d37080 --- /dev/null +++ b/protocol/endpoint.go @@ -0,0 +1,35 @@ +package protocol + +import ( + "fmt" + "net" + "strconv" +) + +// Endpoint represents a network endpoint. +type Endpoint struct { + Address string `yaml:"address" validate:"required,hostname_rfc1123|ip"` + Port int `yaml:"port" validate:"gt=0,lte=65535"` +} + +// TCPAddr returns the TCP address of the endpoint. +func (e *Endpoint) TCPAddr() (*net.TCPAddr, error) { + ip, err := net.ResolveTCPAddr("tcp", net.JoinHostPort(e.Address, strconv.Itoa(e.Port))) + if err != nil { + return nil, fmt.Errorf("resolve address: %w", err) + } + return ip, nil +} + +// Validate the endpoint. +func (e *Endpoint) Validate() error { + if e.Address == "" { + return fmt.Errorf("%w: address is required", ErrValidationFailed) + } + + if e.Port <= 0 || e.Port > 65535 { + return fmt.Errorf("%w: port must be between 1 and 65535", ErrValidationFailed) + } + + return nil +} From b5a96fd50ff3f528617565cdfb95fe5497327e41 Mon Sep 17 00:00:00 2001 From: Kimmo Lehto Date: Sat, 6 Apr 2024 00:53:30 +0300 Subject: [PATCH 7/7] Cache ssh config parser instances Signed-off-by: Kimmo Lehto --- protocol/ssh/parser.go | 89 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 protocol/ssh/parser.go diff --git a/protocol/ssh/parser.go b/protocol/ssh/parser.go new file mode 100644 index 00000000..dc0f8efa --- /dev/null +++ b/protocol/ssh/parser.go @@ -0,0 +1,89 @@ +package ssh + +import ( + "context" + "fmt" + "os" + "sync" + + "github.com/k0sproject/rig/v2/homedir" + "github.com/k0sproject/rig/v2/log" + "github.com/k0sproject/rig/v2/sshconfig" +) + +// ConfigParser is an interface for applying an SSH configuration to an object. +type ConfigParser interface { + Apply(obj any, hostalias string) error +} + +type parserGetter interface { + Get(path string) (ConfigParser, error) +} + +// ParserCache is a sync.OnceValue that creates a new parserCache for caching +// ssh configuration parsers to avoid re-parsing the same configs multiple times. +var ParserCache = sync.OnceValue(func() parserGetter { + return &parserCache{ + cache: make(map[string]*sshconfig.Parser), + errCache: make(map[string]error), + } +}) + +type parserCache struct { + sync.Mutex + cache map[string]*sshconfig.Parser + errCache map[string]error +} + +func (c *parserCache) Get(path string) (ConfigParser, error) { + c.Lock() + defer c.Unlock() + + if err, ok := c.errCache[path]; ok { + return nil, err + } + + if parser, ok := c.cache[path]; ok { + log.Trace(context.Background(), "ssh config parser cache hit", "path", path) + return parser, nil + } + log.Trace(context.Background(), "ssh config parser cache miss", "path", path) + + if path == "" { + log.Trace(context.Background(), "creating a default locations ssh config parser") + parser, err := sshconfig.NewParser(nil) + if err != nil { + err = fmt.Errorf("create ssh config parser using system paths: %w", err) + c.errCache[path] = err + return nil, err + } + c.cache[path] = parser + return parser, nil + } + + expanded, err := homedir.Expand(path) + if err != nil { + err = fmt.Errorf("expand ssh config path %q: %w", path, err) + c.errCache[path] = err + return nil, err + } + + f, err := os.Open(expanded) + if err != nil { + err = fmt.Errorf("open ssh config %q: %w", expanded, err) + c.errCache[path] = err + return nil, err + } + defer f.Close() + + log.Trace(context.Background(), "creating a ssh config parser", log.KeyFile, expanded) + parser, err := sshconfig.NewParser(f) + if err != nil { + err = fmt.Errorf("parse ssh config %q: %w", expanded, err) + c.errCache[path] = err + return nil, err + } + + c.cache[path] = parser + return parser, nil +}