Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(wip) Clean up ssh #189

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"errors"
"fmt"
"io"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -304,7 +305,9 @@ func (c *Client) Connect(ctx context.Context) error {
err := retry.DoWithContext(ctx, func(ctx context.Context) error {
return c.connect(ctx)
}, retry.If(
func(err error) bool { return !errors.Is(err, protocol.ErrAbort) },
func(err error) bool {
return !errors.Is(err, protocol.ErrAbort) && !strings.Contains(err.Error(), "no supported methods")
},
))
if err != nil {
return fmt.Errorf("client connect: %w", err)
Expand Down
1 change: 1 addition & 0 deletions protocol/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,4 @@
ProcessStarter
WindowsChecker
}

Check failure on line 63 in protocol/connection.go

View workflow job for this annotation

GitHub Actions / Lint

File is not `gci`-ed with --skip-generated -s standard -s default (gci)
35 changes: 35 additions & 0 deletions protocol/endpoint.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package protocol

import (
"fmt"
"net"
"strconv"
)

// Endpoint represents a network endpoint.
type Endpoint struct {
Address string `yaml:"address" validate:"required,hostname_rfc1123|ip"`
Port int `yaml:"port" validate:"gt=0,lte=65535"`
}

// TCPAddr returns the TCP address of the endpoint.
func (e *Endpoint) TCPAddr() (*net.TCPAddr, error) {
ip, err := net.ResolveTCPAddr("tcp", net.JoinHostPort(e.Address, strconv.Itoa(e.Port)))
if err != nil {
return nil, fmt.Errorf("resolve address: %w", err)
}
return ip, nil
}

// Validate the endpoint.
func (e *Endpoint) Validate() error {
if e.Address == "" {
return fmt.Errorf("%w: address is required", ErrValidationFailed)
}

if e.Port <= 0 || e.Port > 65535 {
return fmt.Errorf("%w: port must be between 1 and 65535", ErrValidationFailed)
}

return nil
}
33 changes: 13 additions & 20 deletions protocol/openssh/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@ import (

// Config describes the configuration options for an OpenSSH connection.
type Config struct {
Address string `yaml:"address" validate:"required"`
protocol.Endpoint `yaml:",inline"`
User *string `yaml:"user"`
Port *int `yaml:"port"`
KeyPath *string `yaml:"keyPath,omitempty"`
ConfigPath *string `yaml:"configPath,omitempty"`
Options OptionArguments `yaml:"options,omitempty"`
Expand All @@ -27,34 +26,28 @@ func (c *Config) Connection() (protocol.Connection, error) {

// String returns a string representation of the configuration.
func (c *Config) String() string {
if c.Port == nil {
if c.Port == 0 {
return "openssh.Config{" + c.Address + "}"
}
return "openssh.Config{" + net.JoinHostPort(c.Address, strconv.Itoa(*c.Port)) + "}"
return "openssh.Config{" + net.JoinHostPort(c.Address, strconv.Itoa(c.Port)) + "}"
}

// SetDefaults sets the default values for the configuration.
func (c *Config) SetDefaults() {
func (c *Config) SetDefaults() error {
if c.KeyPath != nil {
if path, err := homedir.Expand(*c.KeyPath); err == nil {
c.KeyPath = &path
path, err := homedir.Expand(*c.KeyPath)
if err != nil {
return fmt.Errorf("keypath: %w", err)
}
c.KeyPath = &path
}

if c.ConfigPath != nil {
if path, err := homedir.Expand(*c.ConfigPath); err == nil {
c.ConfigPath = &path
path, err := homedir.Expand(*c.ConfigPath)
if err != nil {
return fmt.Errorf("configpath: %w", err)
}
}
}

// Validate checks the configuration for any invalid values.
func (c *Config) Validate() error {
if c.Address == "" {
return fmt.Errorf("%w: address is required", protocol.ErrValidationFailed)
}

if c.Port != nil && (*c.Port <= 0 || *c.Port > 65535) {
return fmt.Errorf("%w: port must be between 1 and 65535", protocol.ErrValidationFailed)
c.ConfigPath = &path
}

return nil
Expand Down
12 changes: 7 additions & 5 deletions protocol/openssh/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ type Connection struct {

// NewConnection creates a new OpenSSH connection. Error is currently always nil.
func NewConnection(cfg Config) (*Connection, error) {
cfg.SetDefaults()
if err := cfg.SetDefaults(); err != nil {
return nil, fmt.Errorf("set openssh config defaults: %w", err)
}
return &Connection{Config: cfg}, nil
}

Expand Down Expand Up @@ -123,8 +125,8 @@ func (c *Connection) args() []string {
if c.KeyPath != nil && *c.KeyPath != "" {
args = append(args, "-i", *c.KeyPath)
}
if c.Port != nil {
args = append(args, "-p", strconv.Itoa(*c.Port))
if c.Port != 0 {
args = append(args, "-p", strconv.Itoa(c.Port))
}
if c.ConfigPath != nil && *c.ConfigPath != "" {
args = append(args, "-F", *c.ConfigPath)
Expand Down Expand Up @@ -250,10 +252,10 @@ func (c *Connection) String() string {
return c.name
}

if c.Port == nil {
if c.Port == 0 {
c.name = c.userhost()
} else {
c.name = fmt.Sprintf("%s:%d", c.userhost(), *c.Port)
c.name = fmt.Sprintf("%s:%d", c.userhost(), c.Port)
}

return c.name
Expand Down
82 changes: 63 additions & 19 deletions protocol/ssh/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/k0sproject/rig/v2/homedir"
"github.com/k0sproject/rig/v2/log"
"github.com/k0sproject/rig/v2/protocol"
"github.com/k0sproject/rig/v2/sshconfig"
ssh "golang.org/x/crypto/ssh"
)

Expand All @@ -17,11 +18,11 @@ type PasswordCallback func() (secret string, err error)
// Config describes an SSH connection's configuration.
type Config struct {
log.LoggerInjectable `yaml:"-"`
Address string `yaml:"address" validate:"required,hostname_rfc1123|ip"`
protocol.Endpoint `yaml:",inline"`
User string `yaml:"user" validate:"required" default:"root"`
Port int `yaml:"port" default:"22" validate:"gt=0,lte=65535"`
KeyPath *string `yaml:"keyPath" validate:"omitempty"`
Bastion *Config `yaml:"bastion,omitempty"`
ConfigPath string `yaml:"configPath,omitempty"`
PasswordCallback PasswordCallback `yaml:"-"`

// AuthMethods can be used to pass in a list of crypto/ssh.AuthMethod objects
Expand All @@ -30,45 +31,88 @@ type Config struct {
// For convenience, you can use ParseSSHPrivateKey() to parse a private key:
// authMethods, err := ssh.ParseSSHPrivateKey(key, rig.DefaultPassphraseCallback)
AuthMethods []ssh.AuthMethod `yaml:"-"`

sshconfig.Config `yaml:",inline"`

options *Options
}

// Connection returns a new Connection object based on the configuration.
func (c *Config) Connection() (protocol.Connection, error) {
conn, err := NewConnection(*c, WithLogger(c.Log()))
conn, err := NewConnection(*c, c.options.Funcs()...)
if !log.HasLogger(conn) && log.HasLogger(c) {
log.InjectLogger(c.Log(), c)
}
return conn, err
}

// String returns a string representation of the configuration.
func (c *Config) String() string {
return "ssh.Config{" + net.JoinHostPort(c.Address, strconv.Itoa(c.Port)) + "}"
return "ssh.Config{" + net.JoinHostPort(c.Address, strconv.Itoa(c.Endpoint.Port)) + "}"
}

// SetDefaults sets the default values for the configuration.
func (c *Config) SetDefaults() {
if c.Port == 0 {
c.Port = 22
func (c *Config) SetDefaults(opts ...Option) error {
options := NewOptions(opts...)

if c.KeyPath != nil {
path, err := homedir.Expand(*c.KeyPath)
if err != nil {
return fmt.Errorf("keypath: %w", err)
}
c.KeyPath = &path
c.IdentityFile = []string{path}
}
if c.User == "" {
c.User = "root"

c.Host = c.Address
if c.Endpoint.Port != 0 {
c.Config.Port = c.Endpoint.Port
}
if c.KeyPath != nil {
if path, err := homedir.Expand(*c.KeyPath); err == nil {
c.KeyPath = &path

if c.User != "" {
c.Config.User = c.User
}

var parser ConfigParser
if options.ConfigParser != nil {
parser = options.ConfigParser
} else {
p, err := ParserCache().Get(c.ConfigPath)
if err != nil {
return fmt.Errorf("get ssh config parser: %w", err)
}
parser = p
}

if err := parser.Apply(c, c.Address); err != nil {
return fmt.Errorf("apply values from ssh config: %w", err)
}

c.Endpoint.Port = c.Config.Port

if c.Config.User != "" {
c.User = c.Config.User
}

if c.Config.Hostname != "" {
c.Address = c.Config.Hostname
} else {
c.Address = c.Host
}

if c.Bastion != nil {
c.Bastion.SetDefaults()
if err := c.Bastion.SetDefaults(); err != nil {
return fmt.Errorf("bastion: %w", err)
}
}

return nil
}

// Validate returns an error if the configuration is invalid.
func (c *Config) Validate() error {
if c.Address == "" {
return fmt.Errorf("%w: address is required", protocol.ErrValidationFailed)
}

if c.Port <= 0 || c.Port > 65535 {
return fmt.Errorf("%w: port must be between 1 and 65535", protocol.ErrValidationFailed)
if err := c.Endpoint.Validate(); err != nil {
return fmt.Errorf("endpoint: %w", err)
}

if c.KeyPath != nil {
Expand Down
Loading
Loading