Skip to content

Commit

Permalink
refactored interface
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Freeman committed Oct 6, 2024
1 parent c31f8b6 commit 58ed061
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 58 deletions.
16 changes: 8 additions & 8 deletions pkg/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,33 @@ import (
"gofr.dev/pkg/gofr"
)

type Context struct {
type CustomContext struct {
*gofr.Context
claims map[string]interface{}
}

// NewCustomContext creates a new Context.
func NewCustomContext(c *gofr.Context) *Context {
return &Context{
func NewCustomContext(c *gofr.Context) *CustomContext {
return &CustomContext{
Context: c,
claims: make(map[string]interface{}),
}
}

// SetClaim sets the claim value.
func (c *Context) SetClaim(key string, value interface{}) {
func (c *CustomContext) SetClaim(key string, value interface{}) {
c.claims[key] = value
}

// GetClaim returns the claim value.
func (c *Context) GetClaim(key string) (interface{}, bool) {
func (c *CustomContext) GetClaim(key string) (interface{}, bool) {
value, ok := c.claims[key]

return value, ok
}

// GetStringClaim returns the claim value as a string.
func (c *Context) GetStringClaim(key string) (string, bool) {
func (c *CustomContext) GetStringClaim(key string) (string, bool) {
value, ok := c.claims[key]
if !ok {
return "", false
Expand All @@ -43,7 +43,7 @@ func (c *Context) GetStringClaim(key string) (string, bool) {
}

// GetUUIDClaim returns the claim value as a UUID.
func (c *Context) GetUUIDClaim(key string) (uuid.UUID, bool) {
func (c *CustomContext) GetUUIDClaim(key string) (uuid.UUID, bool) {
value, ok := c.claims[key]
if !ok {
return uuid.UUID{}, false
Expand All @@ -55,7 +55,7 @@ func (c *Context) GetUUIDClaim(key string) (uuid.UUID, bool) {
}

// GetAPIKey returns the API key from the context.
func (c *Context) GetAPIKey() (string, bool) {
func (c *CustomContext) GetAPIKey() (string, bool) {
if apiKey, ok := c.Context.Request.Context().Value("APIKey").(string); ok {
return apiKey, true
}
Expand Down
18 changes: 12 additions & 6 deletions pkg/context/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,21 @@ package context
import (
"testing"

"github.com/stretchr/testify/require"

"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)

// GofrContextWrapper is an interface that wraps the methods of gofr.Context that we use
// GofrContextWrapper is an interface that wraps the methods of gofr.Context that we use.
type GofrContextWrapper interface {
Param(key string) string
PathParam(key string) string
Bind(v interface{}) error
}

// MockGofrContextWrapper is a mock for GofrContextWrapper
// MockGofrContextWrapper is a mock for GofrContextWrapper.
type MockGofrContextWrapper struct {
mock.Mock
}
Expand All @@ -35,15 +37,15 @@ func (m *MockGofrContextWrapper) Bind(v interface{}) error {
return args.Error(0)
}

// ContextWrapper wraps our Context and uses GofrContextWrapper instead of *gofr.Context
// ContextWrapper wraps our Context and uses GofrContextWrapper instead of *gofr.Context.
type ContextWrapper struct {
*Context
*CustomContext
gofrCtx GofrContextWrapper
}

func NewContextWrapper(gofrCtx GofrContextWrapper) *ContextWrapper {
return &ContextWrapper{
Context: &Context{
CustomContext: &CustomContext{
Context: nil, // We're not setting this as we're using the wrapper
claims: make(map[string]interface{}),
},
Expand All @@ -60,6 +62,7 @@ func (c *ContextWrapper) GetAPIKey() (string, bool) {
if apiKey != "" {
return apiKey, true
}

return "", false
}

Expand Down Expand Up @@ -125,12 +128,14 @@ func TestGetAPIKey(t *testing.T) {

// Test case when APIKey is present
mockGofrCtx.On("Param", "APIKey").Return("test-api-key").Once()

apiKey, ok := customCtx.GetAPIKey()
assert.True(t, ok)
assert.Equal(t, "test-api-key", apiKey)

// Test case when APIKey is not present
mockGofrCtx.On("Param", "APIKey").Return("").Once()

apiKey, ok = customCtx.GetAPIKey()
assert.False(t, ok)
assert.Empty(t, apiKey)
Expand All @@ -149,6 +154,7 @@ func TestBind(t *testing.T) {
mockGofrCtx.On("Bind", &testStruct).Return(nil)

err := customCtx.Bind(&testStruct)
assert.NoError(t, err)

require.NoError(t, err)
mockGofrCtx.AssertExpectations(t)
}
4 changes: 2 additions & 2 deletions pkg/context/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ import (
"gofr.dev/pkg/gofr"
)

//go:generate mockgen -destination=mock_context.go -package=context github.com/carverauto/eventrunner/pkg/context Interface
//go:generate mockgen -destination=mock_context.go -package=context github.com/carverauto/eventrunner/pkg/context Context

type Interface interface {
type Context interface {
SetClaim(key string, value interface{})
GetClaim(key string) (interface{}, bool)
GetStringClaim(key string) (string, bool)
Expand Down
68 changes: 34 additions & 34 deletions pkg/context/mock_context.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pkg/eventingest/http_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func NewHTTPServer(app *gofr.App, forwarder EventForwarder) *HTTPServer {
}

// HandleEvent handles an event request, it accepts a Context and returns an interface and an error.
func (s *HTTPServer) HandleEvent(cc *customctx.Interface) (interface{}, error) {
func (s *HTTPServer) HandleEvent(cc customctx.Context) (interface{}, error) {
tenantID, ok := cc.GetUUIDClaim("tenant_id")
if !ok {
return nil, NewAuthError("Missing tenant ID")
Expand All @@ -36,7 +36,7 @@ func (s *HTTPServer) HandleEvent(cc *customctx.Interface) (interface{}, error) {
return nil, NewProcessingError("Invalid request body")
}

if err := s.eventForwarder.ForwardEvent(cc.Context, tenantID, customerID, eventData); err != nil {
if err := s.eventForwarder.ForwardEvent(cc.Context(), tenantID, customerID, eventData); err != nil {
return nil, NewProcessingError("Failed to forward event")
}

Expand Down
12 changes: 6 additions & 6 deletions pkg/eventingest/http_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ import (
func TestHandleEvent(t *testing.T) {
tests := []struct {
name string
setupMocks func(*gomock.Controller) (customctx.Interface, EventForwarder)
setupMocks func(*gomock.Controller) (customctx.Context, EventForwarder)
expectedResult interface{}
expectedError error
}{
{
name: "Success",
setupMocks: func(ctrl *gomock.Controller) (customctx.Interface, EventForwarder) {
setupMocks: func(ctrl *gomock.Controller) (customctx.Context, EventForwarder) {
mockCtx := customctx.NewMockInterface(ctrl)
mockEF := NewMockEventForwarder(ctrl)

Expand All @@ -42,7 +42,7 @@ func TestHandleEvent(t *testing.T) {
},
{
name: "Missing tenant ID",
setupMocks: func(ctrl *gomock.Controller) (customctx.Interface, EventForwarder) {
setupMocks: func(ctrl *gomock.Controller) (customctx.Context, EventForwarder) {
mockCtx := customctx.NewMockInterface(ctrl)
mockEF := NewMockEventForwarder(ctrl)

Expand All @@ -55,7 +55,7 @@ func TestHandleEvent(t *testing.T) {
},
{
name: "Missing customer ID",
setupMocks: func(ctrl *gomock.Controller) (customctx.Interface, EventForwarder) {
setupMocks: func(ctrl *gomock.Controller) (customctx.Context, EventForwarder) {
mockCtx := customctx.NewMockInterface(ctrl)
mockEF := NewMockEventForwarder(ctrl)

Expand All @@ -69,7 +69,7 @@ func TestHandleEvent(t *testing.T) {
},
{
name: "Invalid request body",
setupMocks: func(ctrl *gomock.Controller) (customctx.Interface, EventForwarder) {
setupMocks: func(ctrl *gomock.Controller) (customctx.Context, EventForwarder) {
mockCtx := customctx.NewMockInterface(ctrl)
mockEF := NewMockEventForwarder(ctrl)

Expand All @@ -84,7 +84,7 @@ func TestHandleEvent(t *testing.T) {
},
{
name: "Forward event failure",
setupMocks: func(ctrl *gomock.Controller) (customctx.Interface, EventForwarder) {
setupMocks: func(ctrl *gomock.Controller) (customctx.Context, EventForwarder) {
mockCtx := customctx.NewMockInterface(ctrl)
mockEF := NewMockEventForwarder(ctrl)

Expand Down

0 comments on commit 58ed061

Please sign in to comment.