Skip to content

Commit

Permalink
TCP KeyMatch command
Browse files Browse the repository at this point in the history
The new `K` keys command can return arbitrary sized results, so a TCP mode was added to the server and client.

subtree has changed to red-black from AVL, due to being faster for inserts, and we do not necessarily need the perfect sorting of keys.

GC of namespaces will be a bit more consistent - a namespace will not be GC'd two passes in a row.
  • Loading branch information
ruffrey committed May 19, 2023
1 parent d88a317 commit c461266
Show file tree
Hide file tree
Showing 15 changed files with 517 additions and 186 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func main() {
var expireAfterSeconds int64 = 60
preSharedSecret := "supersecret"
s := server.NewServer(expireAfterSeconds, preSharedSecret)
err := s.Listen(3509)
err := s.Listen(3509, 3509)
if err != nil {
panic(err)
}
Expand Down
274 changes: 243 additions & 31 deletions client/client.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
package client

import (
"bytes"
"context"
"errors"
"fmt"
"github.com/mailsac/dracula/client/serverpool"
"github.com/mailsac/dracula/client/waitingmessage"
"github.com/mailsac/dracula/protocol"
"github.com/mailsac/dracula/server/rawmessage"
"io/ioutil"
"log"
"math/rand"
"net"
"os"
"strconv"
Expand All @@ -18,35 +22,64 @@ import (
)

var (
ErrInitNoServers = errors.New("missing dracula udp server list on client init!")
ErrMessageTimedOut = errors.New("timed out waiting for message response")
ErrClientAlreadyInit = errors.New("client already initialized")
ErrCountReturnBytesTooShort = errors.New("too few bytes returned in count callback")
ErrNoHealthyServers = errors.New("no healthy dracula servers")
ErrNoHealthyUDPServers = errors.New("no healthy dracula udp servers")
ErrNoHealthyTCPServers = errors.New("no healthy dracula tcp servers")
)

type Client struct {
// conn is this clients incoming listen connection
// conn is this client's incoming udp listen connection
conn *net.UDPConn
// pool is the list of servers it will communciate with
pool *serverpool.Pool
//remoteServer *net.UDPAddr
// udpPool is the list of remote udp dracula server
udpPool *serverpool.Pool
// tcpPool is the list of remote tcp dracula servers
tcpPool *sync.Pool
tcpPoolMap *sync.Map

tcpServerList []net.TCPAddr

messagesWaiting *waitingmessage.ResponseCache // byte is the expected response command type

messageIDCounter uint32
preSharedKey []byte

disposed bool
log *log.Logger
disposed bool
timeoutDuration time.Duration
log *log.Logger
}

func NewClient(remoteServerIPPortList string, timeout time.Duration, preSharedKey string) *Client {
// Config for the client
type Config struct {
RemoteUDPIPPortList string
RemoteTCPIPPortList string
Timeout time.Duration
PreSharedKey string
}

func NewClient(conf Config) *Client {
var servers []*net.UDPAddr
parts := strings.Split(remoteServerIPPortList, ",")
if len(parts) < 1 {
panic("missing dracula server list on client init!")
if conf.Timeout == 0 {
conf.Timeout = time.Second
}
client := &Client{
preSharedKey: []byte(conf.PreSharedKey),
messagesWaiting: waitingmessage.NewCache(conf.Timeout),
log: log.New(os.Stdout, "", 0),
tcpPoolMap: &sync.Map{},
timeoutDuration: conf.Timeout,
}
for _, ipPort := range parts {
p := strings.Split(ipPort, ":")

udpParts := strings.Split(strings.Trim(conf.RemoteUDPIPPortList, " "), ",")
tcpParts := strings.Split(strings.Trim(conf.RemoteTCPIPPortList, " "), ",")

for _, ipPort := range udpParts {
p := strings.Split(strings.Trim(ipPort, " "), ":")
if p[0] == "" {
continue
}
if len(p) != 2 {
panic(fmt.Errorf("bad <ip:port> dracula client init %s", ipPort))
}
Expand All @@ -59,15 +92,58 @@ func NewClient(remoteServerIPPortList string, timeout time.Duration, preSharedKe
Port: sport,
})
}
c := &Client{
preSharedKey: []byte(preSharedKey),
messagesWaiting: waitingmessage.NewCache(timeout),
log: log.New(os.Stdout, "", 0),
client.udpPool = serverpool.NewPool(client, servers)

// now parse tcp servers - not required
if tcpParts[0] != "" {
for _, ipPort := range tcpParts {
p := strings.Split(strings.Trim(ipPort, " "), ":")
if p[0] == "" {
continue
}
if len(p) != 2 {
panic(fmt.Errorf("bad <ip:port> dracula tcp client init %s", ipPort))
}
sport, err := strconv.Atoi(p[1])
if err != nil {
panic(fmt.Errorf("bad ip:<port> dracula tcp client init %s", ipPort))
}
client.tcpServerList = append(client.tcpServerList, net.TCPAddr{
IP: net.ParseIP(p[0]),
Port: sport,
})
}
}

if len(servers) == 0 && len(client.tcpServerList) == 0 {
panic(ErrInitNoServers)
}
c.pool = serverpool.NewPool(c, servers)

c.DebugDisable()
return c
// setup the pool
client.tcpPool = &sync.Pool{
New: func() interface{} {
// Create a new net.TCPConn object for each server.

if len(client.tcpServerList) < 1 {
return nil
}
const maxTries = 5
for i := 0; i < maxTries; i++ {
randServer := client.tcpServerList[rand.Intn(len(client.tcpServerList))]
conn, err := net.DialTCP("tcp", nil, &randServer)
if err != nil {
client.log.Println("Connection to tcp dracula failed", randServer.String(), err)
}
client.tcpPoolMap.Store(conn, true)
return conn
}

return nil
},
}

client.DebugDisable()
return client
}

func (c *Client) GetConn() *net.UDPConn {
Expand Down Expand Up @@ -105,8 +181,8 @@ func (c *Client) Listen(localUDPPort int) error {
go c.handleResponsesForever()
go c.handleTimeouts()

c.pool.Listen()
c.log.Printf("client created server pool %v\n", c.pool.ListServers())
c.udpPool.Listen()
c.log.Printf("client created server udpPool %v\n", c.udpPool.ListServers())

return nil
}
Expand All @@ -119,16 +195,24 @@ func (c *Client) Close() error {
c.disposed = true
c.messagesWaiting.Dispose()

if c.pool != nil {
c.pool.Dispose()
if c.udpPool != nil {
c.udpPool.Dispose()
}
if c.conn != nil {
err = c.conn.Close()
if err != nil {
return err
}
}
// TODO: close down tcp client server

c.tcpPoolMap.Range(func(key, value interface{}) bool {
conn := key.(*net.TCPConn)
if conn != nil {
conn.Close()
}
c.tcpPoolMap.Delete(key)
return true
})

return nil
}
Expand Down Expand Up @@ -223,6 +307,34 @@ func (c *Client) Count(namespace, entryKey string) (int, error) {
return int(output), err
}

// KeyMatch asks for the list of keys over TCP which match the pattern
func (c *Client) KeyMatch(namespace, keyPattern string) ([]string, error) {
messageID := c.makeMessageID()
var wg sync.WaitGroup
var output string
var err error
cb := func(b []byte, e error) {
defer wg.Done()

if e != nil {
err = e
return
}
output = string(b)
}
wg.Add(1)
// callback has been setup, now make the request
sendPacket := protocol.NewPacketFromParts(protocol.CmdTCPOnlyKeys, messageID, []byte(namespace), []byte(keyPattern), c.preSharedKey)
c.sendOrCallbackErr(sendPacket, cb)

wg.Wait() // wait for callback to be called
results := strings.Split(output, "\n")
if results[0] == "" {
results = []string{}
}
return results, err
}

// Healthcheck implements serverpool.Checker
func (c *Client) Healthcheck(specificServer *net.UDPAddr) error {
messageID := c.makeMessageID()
Expand All @@ -237,7 +349,7 @@ func (c *Client) Healthcheck(specificServer *net.UDPAddr) error {
wg.Add(1)
// callback has been setup, now make the request
p := protocol.NewPacketFromParts(protocol.CmdCount, messageID, []byte("server_healthcheck_"+specificServer.String()), []byte("check"), c.preSharedKey)
c._send(p, specificServer, cb)
c._sendUDP(p, specificServer, cb)

wg.Wait() // wait for callback to be called
return err
Expand Down Expand Up @@ -313,8 +425,8 @@ func (c *Client) Put(namespace, value string) error {
return err
}

func (c *Client) _send(packet *protocol.Packet, remoteServer *net.UDPAddr, cb waitingmessage.Callback) {
c.log.Println("client sending packet:", remoteServer, string(packet.Command), packet.MessageID, packet.NamespaceString(), packet.DataValueString())
func (c *Client) _sendUDP(packet *protocol.Packet, remoteServer *net.UDPAddr, cb waitingmessage.Callback) {
c.log.Println("client sending udp packet:", remoteServer, string(packet.Command), packet.MessageID, packet.NamespaceString(), packet.DataValueString())

b, err := packet.Bytes()
if err != nil {
Expand Down Expand Up @@ -344,12 +456,112 @@ func (c *Client) _send(packet *protocol.Packet, remoteServer *net.UDPAddr, cb wa

// ok
}
func (c *Client) _sendTCP(packet *protocol.Packet, cb waitingmessage.Callback) {
c.log.Println("client sending tcp packet:", string(packet.Command), packet.MessageID, packet.NamespaceString(), packet.DataValueString())

// Get a connection from the pool.
key := c.tcpPool.Get()
conn := key.(*net.TCPConn)
defer func() {
if conn == nil {
c.tcpPoolMap.Delete(key)
} else {
c.tcpPool.Put(conn)
}
}()
if conn == nil {
cb([]byte{}, ErrNoHealthyTCPServers)
return
}

// needs stop
packet.DataValue = append(packet.DataValue, protocol.StopSymbol...)

packetBuf, err := packet.Bytes()
if err != nil {
// probably bad packet
cb([]byte{}, err)
return
}

resChan := make(chan *rawmessage.RawMessage)
defer close(resChan)

var msgErr error
// handle timeout
_ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() // Release resources if operation completes before timeout
readOneMessage := func(ctx context.Context) {
for {
select {
case <-ctx.Done():
c.log.Println("tcp connection timed out", packet.Command, packet.MessageID)
conn.Close()
packet.DataValue = []byte(ErrMessageTimedOut.Error())
errPacket, _ := packet.Bytes()
resChan <- &rawmessage.RawMessage{
Message: errPacket,
Remote: nil,
MaybeTcpClient: conn,
}
conn = nil
return
default:
msgErr = rawmessage.ReadOneTcpMessage(c.log, resChan, conn)
if msgErr == nil {
return
}
packet.DataValue = []byte(msgErr.Error())
errPacket, otherErr := packet.Bytes()
if otherErr != nil && err != protocol.ErrBadOutputSize {
c.log.Println("failed constructing error packet after bad tcp read message", msgErr, otherErr)
return
}
resChan <- &rawmessage.RawMessage{
Message: errPacket,
Remote: nil,
MaybeTcpClient: conn,
}
}
}
}

// we are now waiting for the response, so send the message
_, err = conn.Write(packetBuf)
if err != nil {
c.log.Println("client tcp write failed", err)
cb([]byte{}, err)
conn = nil
return
}
go readOneMessage(_ctx)

// block waiting for response
res := <-resChan

// ok

// tcp callbacks are direct, they don't go through waitingmessages on a separate
// port, but don't callback the entire packet
resPacket, err := protocol.ParsePacket(*protocol.PadRight(&res.Message, protocol.PacketSize))
if err != nil && err != protocol.ErrInvalidPacketSizeTooLarge {
c.log.Println("client tcp parse res packet failed", err, "|"+string(res.Message)+"|")
cb([]byte{}, err)
return
}
cb(bytes.TrimSpace(resPacket.DataValue), nil)
}

func (c *Client) sendOrCallbackErr(packet *protocol.Packet, cb waitingmessage.Callback) {
remoteServer := c.pool.Choose()
if protocol.IsTcpOnlyCmd(packet.Command) {
c._sendTCP(packet, cb)
return
}
remoteServer := c.udpPool.Choose()
if remoteServer == nil {
cb([]byte{}, ErrNoHealthyServers)
c.log.Println("No healthy udp servers")
cb([]byte{}, ErrNoHealthyUDPServers)
return
}
c._send(packet, remoteServer, cb)
c._sendUDP(packet, remoteServer, cb)
}
Loading

0 comments on commit c461266

Please sign in to comment.