diff --git a/command.go b/command.go index f216f7d..7273089 100644 --- a/command.go +++ b/command.go @@ -556,7 +556,7 @@ func loadCommand(writer *resp.Writer, _ []resp.RESP) { writer.WriteError(err) return } - writer.WriteBulkString("OK") + writer.WriteSString("OK") } func saveCommand(writer *resp.Writer, _ []resp.RESP) { @@ -564,7 +564,7 @@ func saveCommand(writer *resp.Writer, _ []resp.RESP) { writer.WriteError(err) return } - writer.WriteBulkString("OK") + writer.WriteSString("OK") } func evalCommand(writer *resp.Writer, args []resp.RESP) { diff --git a/command_test.go b/command_test.go index 8b2b558..94ee1af 100644 --- a/command_test.go +++ b/command_test.go @@ -488,11 +488,12 @@ func testCommand(t *testing.T, testType string, rdb *redis.Client, sleepFn func( } }) - t.Run("save", func(t *testing.T) { + t.Run("save-load", func(t *testing.T) { rdb.FlushDB(ctx) // set key - rdb.Set(ctx, "rdb-key1", 123, 0) - rdb.Set(ctx, "rdb-key2", 123, time.Minute) + rdb.Set(ctx, "rdb-key1", "123", 0) + rdb.Set(ctx, "rdb-key2", "234", time.Minute) + rdb.Set(ctx, "rdb-key3", "345", 1) rdb.Incr(ctx, "key-incr") rdb.HSet(ctx, "rdb-hash1", "k1", "v1", "k2", "v2") rdb.SAdd(ctx, "rdb-set1", "k1", "k2") @@ -502,38 +503,61 @@ func testCommand(t *testing.T, testType string, rdb *redis.Client, sleepFn func( rdb.SAdd(ctx, "rdb-set2", key) } rdb.RPush(ctx, "rdb-list1", "k1", "k2", "k3") + rdb.ZAdd(ctx, "rdb-zset1", + redis.Z{Score: 200, Member: "k2"}, + redis.Z{Score: 100, Member: "k1"}, + redis.Z{Score: 300, Member: "k3"}) - res, err := rdb.Save(context.Background()).Result() - ast.Nil(err) + res, _ := rdb.Save(context.Background()).Result() ast.Equal(res, "OK") - _, err = rdb.Do(ctx, "load").Result() + _, err := rdb.Do(ctx, "load").Result() ast.Nil(err) + + // valid + res, _ = rdb.Get(ctx, "rdb-key1").Result() + ast.Equal(res, "123") + res, _ = rdb.Get(ctx, "rdb-key2").Result() + ast.Equal(res, "234") + _, err = rdb.Get(ctx, "rdb-key3").Result() + ast.Equal(err, redis.Nil) + + res, _ = rdb.Get(ctx, "key-incr").Result() + ast.Equal(res, "1") + + resm, _ := rdb.HGetAll(ctx, "rdb-hash1").Result() + ast.Equal(resm, map[string]string{"k1": "v1", "k2": "v2"}) + + ress, _ := rdb.SMembers(ctx, "rdb-set1").Result() + ast.ElementsMatch(ress, []string{"k1", "k2"}) + + ress, _ = rdb.LRange(ctx, "rdb-list1", 0, -1).Result() + ast.Equal(ress, []string{"k1", "k2", "k3"}) + + resz, _ := rdb.ZPopMin(ctx, "rdb-zset1").Result() + ast.Equal(resz, []redis.Z{{ + Member: "k1", Score: 100, + }}) }) } - t.Run("closed", func(t *testing.T) { - err := rdb.Close() - ast.Nil(err) + t.Run("close", func(t *testing.T) { + ast.Nil(rdb.Close()) }) } func TestConfig(t *testing.T) { ast := assert.New(t) - cfg, _ := LoadConfig("config.json") ast.Equal(cfg.Port, 6379) - _, err := LoadConfig("not-exist.json") ast.NotNil(err) - _, err = LoadConfig("go.mod") ast.NotNil(err) } func TestReadableSize(t *testing.T) { ast := assert.New(t) - ast.Equal(readableSize(50), "50B") ast.Equal(readableSize(50*KB), "50.0KB") ast.Equal(readableSize(50*MB), "50.0MB") diff --git a/const.go b/const.go index c944d53..8ce8bdd 100644 --- a/const.go +++ b/const.go @@ -4,13 +4,14 @@ import ( "github.com/xgzlucario/rotom/internal/hash" "github.com/xgzlucario/rotom/internal/iface" "github.com/xgzlucario/rotom/internal/list" + "github.com/xgzlucario/rotom/internal/zset" ) type ObjectType byte const ( TypeUnknown ObjectType = iota - TypeString ObjectType = iota + TypeString TypeInteger TypeMap TypeZipMap @@ -32,4 +33,5 @@ var type2c = map[ObjectType]func() iface.Encoder{ TypeSet: func() iface.Encoder { return hash.NewSet() }, TypeZipSet: func() iface.Encoder { return hash.NewZipSet() }, TypeList: func() iface.Encoder { return list.New() }, + TypeZSet: func() iface.Encoder { return zset.New() }, } diff --git a/go.mod b/go.mod index 93cc5c1..dc6686c 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,7 @@ require ( github.com/yuin/gopher-lua v1.1.1 github.com/zyedidia/generic v1.2.1 golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c - golang.org/x/sys v0.26.0 + golang.org/x/sys v0.27.0 ) require ( @@ -33,7 +33,7 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.13.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect - golang.org/x/arch v0.11.0 // indirect + golang.org/x/arch v0.12.0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 44c6f9b..7bd61af 100644 --- a/go.sum +++ b/go.sum @@ -78,16 +78,16 @@ github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= github.com/zyedidia/generic v1.2.1 h1:Zv5KS/N2m0XZZiuLS82qheRG4X1o5gsWreGb0hR7XDc= github.com/zyedidia/generic v1.2.1/go.mod h1:ly2RBz4mnz1yeuVbQA/VFwGjK3mnHGRj1JuoG336Bis= -golang.org/x/arch v0.11.0 h1:KXV8WWKCXm6tRpLirl2szsO5j/oOODwZf4hATmGVNs4= -golang.org/x/arch v0.11.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= +golang.org/x/arch v0.12.0 h1:UsYJhbzPYGsT0HbEdmYcqtCv8UNGvnaL561NnIUvaKg= +golang.org/x/arch v0.12.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBnqWrJc+rq0DPKyvvdbY= golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= -golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s= +golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/internal/list/list.go b/internal/list/list.go index f13ce26..0e02aae 100644 --- a/internal/list/list.go +++ b/internal/list/list.go @@ -1,10 +1,16 @@ package list import ( + "github.com/xgzlucario/rotom/internal/iface" "github.com/xgzlucario/rotom/internal/resp" "github.com/zyedidia/generic/list" ) +var ( + _ iface.Encoder = (*ListPack)(nil) + _ iface.Encoder = (*QuickList)(nil) +) + // +------------------------------ QuickList -----------------------------+ // | +-----------+ +-----------+ +-----------+ | // head --- | listpack0 | <-> | listpack1 | <-> ... <-> | listpackN | --- tail diff --git a/internal/resp/resp.go b/internal/resp/resp.go index 4b8d092..9a33279 100644 --- a/internal/resp/resp.go +++ b/internal/resp/resp.go @@ -115,6 +115,14 @@ func (r *Reader) ReadInteger() (int, error) { return r.readInteger() } +func (r *Reader) ReadFloat() (float64, error) { + buf, err := r.ReadBulk() + if err != nil { + return 0, nil + } + return strconv.ParseFloat(b2s(buf), 64) +} + func (r *Reader) ReadBulk() ([]byte, error) { if len(r.b) == 0 || r.b[0] != BULK { return nil, errors.New("command is not begin with BULK") diff --git a/internal/zset/zset.go b/internal/zset/zset.go index 7ced762..ea32dc9 100644 --- a/internal/zset/zset.go +++ b/internal/zset/zset.go @@ -2,11 +2,16 @@ package zset import ( "cmp" - "github.com/bytedance/sonic" + "github.com/xgzlucario/rotom/internal/iface" + "github.com/xgzlucario/rotom/internal/resp" "github.com/chen3feng/stl4go" ) +var ( + _ iface.Encoder = (*ZSet)(nil) +) + type node struct { key string score float64 @@ -99,17 +104,32 @@ func (z *ZSet) Len() int { return len(z.m) } -func (z *ZSet) Marshal() ([]byte, error) { - return sonic.Marshal(z.m) +func (z *ZSet) Encode(writer *resp.Writer) error { + writer.WriteArrayHead(z.Len()) + for k, s := range z.m { + writer.WriteBulkString(k) + writer.WriteFloat(s) + } + return nil } -func (z *ZSet) Unmarshal(src []byte) error { - err := sonic.Unmarshal(src, &z.m) +func (z *ZSet) Decode(reader *resp.Reader) error { + n, err := reader.ReadArrayHead() if err != nil { return err } - for k, v := range z.m { - z.skl.Insert(node{k, v}, struct{}{}) + for range n { + buf, err := reader.ReadBulk() + if err != nil { + return err + } + score, err := reader.ReadFloat() + if err != nil { + return err + } + key := string(buf) + z.skl.Insert(node{key, score}, struct{}{}) + z.m[key] = score } return nil } diff --git a/rdb.go b/rdb.go index f4bc853..6eee2a0 100644 --- a/rdb.go +++ b/rdb.go @@ -1,7 +1,6 @@ package main import ( - "errors" "fmt" "github.com/tidwall/mmap" "github.com/xgzlucario/rotom/internal/iface" @@ -40,23 +39,11 @@ func (r *Rdb) SaveDB() error { switch objectType { case TypeString: - raw, ok := v.([]byte) - if !ok { - return errors.New("invalid data typeString") - } - writer.WriteBulk(raw) + writer.WriteBulk(v.([]byte)) case TypeInteger: - raw, ok := v.(int) - if !ok { - return errors.New("invalid data typeInteger") - } - writer.WriteInteger(raw) + writer.WriteInteger(v.(int)) default: - val, ok := v.(iface.Encoder) - if !ok { - return errors.New("invalid data type") - } - if err := val.Encode(writer); err != nil { + if err = v.(iface.Encoder).Encode(writer); err != nil { return err } } @@ -75,15 +62,8 @@ func (r *Rdb) SaveDB() error { } func (r *Rdb) LoadDB() error { - fs, err := os.OpenFile(r.path, os.O_RDWR, 0644) - if err != nil { - if os.IsNotExist(err) { - return nil - } - return err - } // Read file data by mmap. - data, err := mmap.MapFile(fs, false) + data, err := mmap.Open(r.path, false) if len(data) == 0 { return nil } @@ -92,12 +72,12 @@ func (r *Rdb) LoadDB() error { } reader := resp.NewReader(data) - count, err := reader.ReadArrayHead() + n, err := reader.ReadArrayHead() if err != nil { return err } - for range count { + for range n { // format: {objectType,ttl,key,value} objectType, err := reader.ReadInteger() if err != nil { @@ -129,10 +109,7 @@ func (r *Rdb) LoadDB() error { default: val := type2c[ObjectType(objectType)]() - if val == nil { - return errors.New("invalid data type") - } - if err := val.Decode(reader); err != nil { + if err = val.Decode(reader); err != nil { return err } db.dict.Set(string(key), val) diff --git a/rotom.go b/rotom.go index 30c7a3c..c0e7e18 100644 --- a/rotom.go +++ b/rotom.go @@ -116,14 +116,13 @@ func AcceptHandler(loop *AeLoop, fd int, _ interface{}) { log.Error().Msgf("accept err: %v", err) return } - // create client + log.Info().Msgf("accept new client fd: %d", fd) client := &Client{ fd: cfd, replyWriter: resp.NewWriter(WriteBufSize), queryBuf: make([]byte, QueryBufSize), argsBuf: make([]resp.RESP, 8), } - server.clients[cfd] = client loop.AddRead(cfd, ReadQueryFromClient, client) }