-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add httpproxy * move utils methods from http to util
- Loading branch information
1 parent
b201bc3
commit 0d2f3fa
Showing
4 changed files
with
216 additions
and
0 deletions.
There are no files selected for viewing
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
// Copyright 2024 Canonical. | ||
|
||
package rpc | ||
|
||
import ( | ||
"context" | ||
"crypto/tls" | ||
"crypto/x509" | ||
"fmt" | ||
"io" | ||
"net/http" | ||
"net/url" | ||
|
||
"github.com/juju/zaputil" | ||
"github.com/juju/zaputil/zapctx" | ||
|
||
"github.com/canonical/jimm/v3/internal/dbmodel" | ||
) | ||
|
||
type httpOptions struct { | ||
TLSConfig *tls.Config | ||
URL url.URL | ||
} | ||
|
||
// ProxyHTTP proxies the request to the controller using the info contained in dbmodel.Controller. | ||
func ProxyHTTP(ctx context.Context, ctl *dbmodel.Controller, w http.ResponseWriter, req *http.Request) { | ||
var tlsConfig *tls.Config | ||
if ctl.CACertificate != "" { | ||
cp := x509.NewCertPool() | ||
ok := cp.AppendCertsFromPEM([]byte(ctl.CACertificate)) | ||
if !ok { | ||
zapctx.Warn(ctx, "no CA certificates added") | ||
} | ||
tlsConfig = &tls.Config{ | ||
RootCAs: cp, | ||
ServerName: ctl.TLSHostname, | ||
MinVersion: tls.VersionTLS12, | ||
} | ||
} | ||
|
||
if ctl.PublicAddress != "" { | ||
err := doRequest(ctx, w, req, httpOptions{ | ||
TLSConfig: tlsConfig, | ||
URL: createURLWithNewHost(*req.URL, ctl.PublicAddress), | ||
}) | ||
if err == nil { | ||
return | ||
} | ||
} | ||
for _, hps := range ctl.Addresses { | ||
for _, hp := range hps { | ||
err := doRequest(ctx, w, req, httpOptions{ | ||
TLSConfig: tlsConfig, | ||
URL: createURLWithNewHost(*req.URL, fmt.Sprintf("%s:%d", hp.Value, hp.Port)), | ||
}) | ||
if err == nil { | ||
return | ||
} else { | ||
zapctx.Error(ctx, "failed to proxy request: continue to next addr", zaputil.Error(err)) | ||
} | ||
} | ||
} | ||
|
||
zapctx.Error(ctx, "couldn't find a valid address for controller") | ||
http.Error(w, "Gateway timeout", http.StatusGatewayTimeout) | ||
} | ||
|
||
func doRequest(ctx context.Context, w http.ResponseWriter, req *http.Request, opt httpOptions) error { | ||
client := &http.Client{ | ||
Transport: &http.Transport{ | ||
TLSClientConfig: opt.TLSConfig, | ||
}, | ||
} | ||
req = req.Clone(ctx) | ||
req.RequestURI = "" | ||
req.URL = &opt.URL | ||
resp, err := client.Do(req) | ||
if err != nil { | ||
return err | ||
} | ||
defer resp.Body.Close() | ||
|
||
// copy headers | ||
for k, vv := range resp.Header { | ||
for _, v := range vv { | ||
w.Header().Add(k, v) | ||
} | ||
} | ||
w.WriteHeader(resp.StatusCode) | ||
// copy body | ||
_, err = io.Copy(w, resp.Body) | ||
if err != nil { | ||
return err | ||
} | ||
return nil | ||
} | ||
|
||
// createURLWithNewHost takes a url.URL as parameter and return a url.URL with new host set and https enforced. | ||
func createURLWithNewHost(reqUrl url.URL, host string) url.URL { | ||
reqUrl.Scheme = "https" | ||
reqUrl.Host = host | ||
return reqUrl | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
// Copyright 2024 Canonical. | ||
|
||
package rpc_test | ||
|
||
import ( | ||
"context" | ||
"encoding/pem" | ||
"net/http" | ||
"net/http/httptest" | ||
"net/url" | ||
"strings" | ||
"testing" | ||
|
||
qt "github.com/frankban/quicktest" | ||
"github.com/juju/juju/core/network" | ||
jujuparams "github.com/juju/juju/rpc/params" | ||
|
||
"github.com/canonical/jimm/v3/internal/dbmodel" | ||
"github.com/canonical/jimm/v3/internal/rpc" | ||
) | ||
|
||
func TestProxyHTTP(t *testing.T) { | ||
c := qt.New(t) | ||
ctx := context.Background() | ||
// we expect the controller to respond with TLS | ||
fakeController := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
if strings.HasSuffix(r.URL.String(), "unauth") { | ||
w.WriteHeader(401) | ||
return | ||
} | ||
_, err := w.Write([]byte("OK")) | ||
c.Assert(err, qt.IsNil) | ||
})) | ||
defer fakeController.Close() | ||
controller := dbmodel.Controller{} | ||
pemData := pem.EncodeToMemory(&pem.Block{ | ||
Type: "CERTIFICATE", | ||
Bytes: fakeController.Certificate().Raw, | ||
}) | ||
controller.CACertificate = string(pemData) | ||
|
||
tests := []struct { | ||
description string | ||
setup func() | ||
path string | ||
statusExpected int | ||
}{ | ||
{ | ||
description: "good", | ||
setup: func() { | ||
newURL, _ := url.Parse(fakeController.URL) | ||
controller.PublicAddress = newURL.Host | ||
}, | ||
statusExpected: http.StatusOK, | ||
}, | ||
{ | ||
description: "controller no public address, only addresses", | ||
setup: func() { | ||
hp, err := network.ParseMachineHostPort(fakeController.Listener.Addr().String()) | ||
c.Assert(err, qt.Equals, nil) | ||
controller.Addresses = append(make([][]jujuparams.HostPort, 0), []jujuparams.HostPort{{ | ||
Address: jujuparams.FromMachineAddress(hp.MachineAddress), | ||
Port: hp.Port(), | ||
}}) | ||
controller.Addresses = append(controller.Addresses, []jujuparams.HostPort{}) | ||
controller.PublicAddress = "" | ||
}, | ||
statusExpected: http.StatusOK, | ||
}, | ||
{ | ||
description: "controller no public address, only addresses", | ||
setup: func() { | ||
hp, err := network.ParseMachineHostPort(fakeController.Listener.Addr().String()) | ||
c.Assert(err, qt.Equals, nil) | ||
controller.Addresses = append(make([][]jujuparams.HostPort, 0), []jujuparams.HostPort{{ | ||
Address: jujuparams.FromMachineAddress(hp.MachineAddress), | ||
Port: hp.Port(), | ||
}}) | ||
controller.Addresses = append(controller.Addresses, []jujuparams.HostPort{}) | ||
controller.PublicAddress = "" | ||
}, | ||
statusExpected: http.StatusOK, | ||
}, | ||
{ | ||
description: "controller responds unauthorized", | ||
setup: func() { | ||
newURL, _ := url.Parse(fakeController.URL) | ||
controller.PublicAddress = newURL.Host | ||
}, | ||
path: "/unauth", | ||
statusExpected: http.StatusUnauthorized, | ||
}, | ||
{ | ||
description: "controller not reachable", | ||
setup: func() { | ||
controller.Addresses = nil | ||
controller.PublicAddress = "localhost-not-found:61213" | ||
}, | ||
statusExpected: http.StatusGatewayTimeout, | ||
}, | ||
} | ||
|
||
for _, test := range tests { | ||
test.setup() | ||
req, err := http.NewRequest("POST", test.path, nil) | ||
c.Assert(err, qt.IsNil) | ||
recorder := httptest.NewRecorder() | ||
rpc.ProxyHTTP(ctx, &controller, recorder, req) | ||
resp := recorder.Result() | ||
defer resp.Body.Close() | ||
c.Assert(resp.StatusCode, qt.Equals, test.statusExpected) | ||
} | ||
} |