Skip to content

Commit

Permalink
API endpoint and associated RPC changes to resolve VMs IP's (#188)
Browse files Browse the repository at this point in the history
* API endpoint and associated RPC changes to resolve VMs IP's

* Fix "Missing expected argument '<name>'" error when doing "tart set"

* Implement TestIPEndpoint() and IP() method in controller HTTP client
  • Loading branch information
edigaryev committed Jul 3, 2024
1 parent 8119b22 commit 76f192b
Show file tree
Hide file tree
Showing 18 changed files with 603 additions and 124 deletions.
32 changes: 32 additions & 0 deletions api/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,31 @@ paths:
description: VM resource with the given name doesn't exist
'503':
description: Failed to establish connection with the worker responsible for the specified VM
/vms/{name}/ip:
get:
summary: "Resolve the VM's IP address on the worker"
tags:
- vms
parameters:
- in: query
name: wait
description: Duration in seconds to wait for the VM to transition into "running" state if not already running.
schema:
type: integer
minimum: 0
maximum: 65535
required: false
responses:
'200':
description: OK
content:
application/json:
schema:
$ref: '#/components/schemas/IP'
'404':
description: VM resource with the given name doesn't exist
'503':
description: Failed to resolve the IP address on the worker responsible for the specified VM
components:
schemas:
Worker:
Expand Down Expand Up @@ -369,6 +394,13 @@ components:
type: object
items:
$ref: '#/components/schemas/Event'
IP:
title: Result of VM's IP resolution
type: object
properties:
ip:
type: string
description: The resolved IP address
Event:
title: Generic Resource Event
type: object
Expand Down
3 changes: 3 additions & 0 deletions internal/controller/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ func (controller *Controller) initAPI() *gin.Engine {
v1.GET("/vms/:name/port-forward", func(c *gin.Context) {
controller.portForwardVM(c).Respond(c)
})
v1.GET("/vms/:name/ip", func(c *gin.Context) {
controller.ip(c).Respond(c)
})
v1.DELETE("/vms/:name", func(c *gin.Context) {
controller.deleteVM(c).Respond(c)
})
Expand Down
69 changes: 69 additions & 0 deletions internal/controller/api_vms_ip.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package controller

import (
"context"
"github.com/cirruslabs/orchard/internal/responder"
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
"github.com/cirruslabs/orchard/rpc"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"net/http"
"strconv"
"time"
)

// ip handles GET /vms/:name/ip: it asks the worker that runs the named VM to
// resolve the VM's IP address and replies with a JSON object of the form
// {"ip": "..."}.
//
// The optional "wait" query parameter is a number of seconds (0..65535) that
// bounds the VM look-up performed by waitForVM. When the parameter is omitted
// it defaults to 0; a value that is empty, non-numeric or out of range yields
// 400 Bad Request.
//
// A failure to notify the worker yields 503 Service Unavailable, and a
// cancelled request context is reported through responder.Error.
func (controller *Controller) ip(ctx *gin.Context) responder.Responder {
	if responder := controller.authorize(ctx, v1.ServiceAccountRoleComputeWrite); responder != nil {
		return responder
	}

	// Retrieve and parse path and query parameters
	name := ctx.Param("name")

	// "wait" is optional, so fall back to "0" when it is absent instead of
	// feeding an empty string to ParseUint (which would be rejected as a
	// bad request). The bit size of 16 caps the accepted value at 65535.
	waitRaw := ctx.DefaultQuery("wait", "0")
	wait, err := strconv.ParseUint(waitRaw, 10, 16)
	if err != nil {
		return responder.Code(http.StatusBadRequest)
	}

	waitContext, waitContextCancel := context.WithTimeout(ctx, time.Duration(wait)*time.Second)
	defer waitContextCancel()

	// Look-up the VM
	vm, responderImpl := controller.waitForVM(waitContext, name)
	if responderImpl != nil {
		return responderImpl
	}

	// Register a rendezvous session before notifying the worker, so that
	// the worker's answer always finds a slot to be delivered to
	session := uuid.New().String()
	ipCh, cancel := controller.ipRendezvous.Request(ctx, session)
	defer cancel()

	// Send an IP resolution request to the worker
	err = controller.workerNotifier.Notify(ctx, vm.Worker, &rpc.WatchInstruction{
		Action: &rpc.WatchInstruction_ResolveIpAction{
			ResolveIpAction: &rpc.WatchInstruction_ResolveIP{
				Session: session,
				VmUid:   vm.UID,
			},
		},
	})
	if err != nil {
		controller.logger.Warnf("failed to request VM's IP from the worker %s: %v",
			vm.Worker, err)

		return responder.Code(http.StatusServiceUnavailable)
	}

	// Wait for either the resolved IP or the request's cancellation
	select {
	case ip := <-ipCh:
		result := struct {
			IP string `json:"ip"`
		}{
			IP: ip,
		}

		return responder.JSON(http.StatusOK, &result)
	case <-ctx.Done():
		return responder.Error(ctx.Err())
	}
}
2 changes: 1 addition & 1 deletion internal/controller/api_vms_portforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (controller *Controller) portForward(
) responder.Responder {
// Request and wait for a connection with a worker
session := uuid.New().String()
boomerangConnCh, cancel := controller.proxy.Request(ctx, session)
boomerangConnCh, cancel := controller.connRendezvous.Request(ctx, session)
defer cancel()

// send request to worker to initiate port-forwarding connection back to us
Expand Down
10 changes: 6 additions & 4 deletions internal/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"errors"
"fmt"
"github.com/cirruslabs/orchard/internal/controller/notifier"
"github.com/cirruslabs/orchard/internal/controller/proxy"
"github.com/cirruslabs/orchard/internal/controller/rendezvous"
"github.com/cirruslabs/orchard/internal/controller/scheduler"
"github.com/cirruslabs/orchard/internal/controller/sshserver"
storepkg "github.com/cirruslabs/orchard/internal/controller/store"
Expand Down Expand Up @@ -54,7 +54,8 @@ type Controller struct {
logger *zap.SugaredLogger
grpcServer *grpc.Server
workerNotifier *notifier.Notifier
proxy *proxy.Proxy
connRendezvous *rendezvous.Rendezvous[net.Conn]
ipRendezvous *rendezvous.Rendezvous[string]
enableSwaggerDocs bool
workerOfflineTimeout time.Duration
maxWorkersPerLicense uint
Expand All @@ -69,7 +70,8 @@ type Controller struct {

func New(opts ...Option) (*Controller, error) {
controller := &Controller{
proxy: proxy.NewProxy(),
connRendezvous: rendezvous.New[net.Conn](),
ipRendezvous: rendezvous.New[string](),
workerOfflineTimeout: 3 * time.Minute,
maxWorkersPerLicense: maxWorkersPerDefaultLicense,
}
Expand Down Expand Up @@ -125,7 +127,7 @@ func New(opts ...Option) (*Controller, error) {
// Instantiate the SSH server (if configured)
if controller.sshListenAddr != "" && controller.sshSigner != nil {
controller.sshServer, err = sshserver.NewSSHServer(controller.sshListenAddr, controller.sshSigner,
store, controller.proxy, controller.workerNotifier, controller.sshNoClientAuth, controller.logger)
store, controller.connRendezvous, controller.workerNotifier, controller.sshNoClientAuth, controller.logger)
if err != nil {
return nil, err
}
Expand Down
49 changes: 0 additions & 49 deletions internal/controller/proxy/proxy.go

This file was deleted.

48 changes: 48 additions & 0 deletions internal/controller/rendezvous/rendezvous.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package rendezvous

import (
"context"
"errors"
"github.com/cirruslabs/orchard/internal/concurrentmap"
)

// ErrInvalidToken is returned by Respond when no request is currently
// registered under the given session.
var ErrInvalidToken = errors.New("invalid rendezvous token")

// Rendezvous pairs a requester waiting for a value of type T with a responder
// that later delivers it, matching the two by a session string.
type Rendezvous[T any] struct {
// sessions maps a session string to the slot of the pending request.
sessions *concurrentmap.ConcurrentMap[*TokenSlot[T]]
}

// TokenSlot is the state kept for a single pending request.
type TokenSlot[T any] struct {
// ctx is the context passed to Request; Respond hands it back to the responder.
ctx context.Context
// ch is the unbuffered channel on which the responder's value is delivered.
ch chan T
}

// New creates a Rendezvous with an empty set of pending sessions.
func New[T any]() *Rendezvous[T] {
	r := new(Rendezvous[T])
	r.sessions = concurrentmap.NewConcurrentMap[*TokenSlot[T]]()

	return r
}

// Request registers a pending request under the given session and returns
// the channel on which the responder's value will arrive, together with
// a function that unregisters the session again.
//
// The supplied ctx is remembered and later returned to whoever calls
// Respond for this session.
func (rendezvous *Rendezvous[T]) Request(ctx context.Context, session string) (chan T, func()) {
	ch := make(chan T)

	rendezvous.sessions.Store(session, &TokenSlot[T]{ctx: ctx, ch: ch})

	unregister := func() {
		rendezvous.sessions.Delete(session)
	}

	return ch, unregister
}

// Respond delivers conn to the requester waiting on the given session and
// returns the context that the requester passed to Request.
//
// It returns ErrInvalidToken when no request is registered under session.
// When the requester's context is done before the value can be handed over
// (e.g. the requester gave up but the session has not been unregistered yet),
// Respond returns that context's error instead of blocking forever on the
// unbuffered channel.
func (rendezvous *Rendezvous[T]) Respond(session string, conn T) (context.Context, error) {
	tokenSlot, ok := rendezvous.sessions.Load(session)
	if !ok {
		return nil, ErrInvalidToken
	}

	select {
	case tokenSlot.ch <- conn:
		return tokenSlot.ctx, nil
	case <-tokenSlot.ctx.Done():
		// Nobody is going to receive the value anymore
		return nil, tokenSlot.ctx.Err()
	}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package proxy_test
package rendezvous_test

import (
"context"
"github.com/cirruslabs/orchard/internal/controller/proxy"
"github.com/cirruslabs/orchard/internal/controller/rendezvous"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"net"
Expand All @@ -15,7 +15,7 @@ func TestProxy(t *testing.T) {

expectedConn, _ := net.Pipe()

proxy := proxy.NewProxy()
proxy := rendezvous.New[net.Conn]()

token := uuid.New().String()

Expand Down
24 changes: 22 additions & 2 deletions internal/controller/rpc.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package controller

import (
"context"
v1pkg "github.com/cirruslabs/orchard/pkg/resource/v1"
"github.com/cirruslabs/orchard/rpc"
"google.golang.org/grpc/metadata"
Expand Down Expand Up @@ -61,8 +62,8 @@ func (controller *Controller) PortForward(stream rpc.Controller_PortForwardServe
}),
}

// make proxy aware of the connection
proxyCtx, err := controller.proxy.Respond(sessionMetadataValue[0], conn)
// make connection rendezvous aware of the connection
proxyCtx, err := controller.connRendezvous.Respond(sessionMetadataValue[0], conn)
if err != nil {
return err
}
Expand All @@ -74,3 +75,22 @@ func (controller *Controller) PortForward(stream rpc.Controller_PortForwardServe
return stream.Context().Err()
}
}

// ResolveIP is the RPC through which a worker reports the result of an IP
// resolution. The session is read from the incoming gRPC metadata and the
// reported IP is handed to the request pending under that session.
//
// It fails with Unauthenticated when authorization fails, with
// InvalidArgument when the metadata carries no session, and otherwise
// propagates the error returned by the IP rendezvous.
func (controller *Controller) ResolveIP(ctx context.Context, request *rpc.ResolveIPResult) (*emptypb.Empty, error) {
	if authorized := controller.authorizeGRPC(ctx, v1pkg.ServiceAccountRoleComputeWrite); !authorized {
		return nil, status.Errorf(codes.Unauthenticated, "auth failed")
	}

	// The session is carried in the same metadata key that port-forwarding uses
	sessions := metadata.ValueFromIncomingContext(ctx, rpc.MetadataWorkerPortForwardingSessionKey)
	if len(sessions) == 0 {
		return nil, status.Errorf(codes.InvalidArgument, "no session in metadata")
	}

	// Hand the resolved IP address over to the waiting requester
	if _, err := controller.ipRendezvous.Respond(sessions[0], request.Ip); err != nil {
		return nil, err
	}

	return &emptypb.Empty{}, nil
}
10 changes: 5 additions & 5 deletions internal/controller/sshserver/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"errors"
"fmt"
"github.com/cirruslabs/orchard/internal/controller/notifier"
proxypkg "github.com/cirruslabs/orchard/internal/controller/proxy"
"github.com/cirruslabs/orchard/internal/controller/rendezvous"
storepkg "github.com/cirruslabs/orchard/internal/controller/store"
"github.com/cirruslabs/orchard/internal/proxy"
"github.com/cirruslabs/orchard/pkg/resource/v1"
Expand All @@ -31,7 +31,7 @@ type SSHServer struct {
listener net.Listener
serverConfig *ssh.ServerConfig
store storepkg.Store
proxy *proxypkg.Proxy
connRendezvous *rendezvous.Rendezvous[net.Conn]
workerNotifier *notifier.Notifier
logger *zap.SugaredLogger
}
Expand All @@ -40,14 +40,14 @@ func NewSSHServer(
address string,
signer ssh.Signer,
store storepkg.Store,
proxy *proxypkg.Proxy,
connRendezvous *rendezvous.Rendezvous[net.Conn],
workerNotifier *notifier.Notifier,
noClientAuth bool,
logger *zap.SugaredLogger,
) (*SSHServer, error) {
server := &SSHServer{
store: store,
proxy: proxy,
connRendezvous: connRendezvous,
workerNotifier: workerNotifier,
logger: logger,
}
Expand Down Expand Up @@ -232,7 +232,7 @@ func (server *SSHServer) handleDirectTCPIP(ctx context.Context, newChannel ssh.N
// The user wants to connect to an existing VM, request and wait
// for a connection with the worker before accepting the channel
session := uuid.New().String()
boomerangConnCh, cancel := server.proxy.Request(ctx, session)
boomerangConnCh, cancel := server.connRendezvous.Request(ctx, session)
defer cancel()

err = server.workerNotifier.Notify(ctx, vm.Worker, &rpc.WatchInstruction{
Expand Down
Loading

0 comments on commit 76f192b

Please sign in to comment.