optimize readPacket (#1705)
Avoid unnecessary allocation.
diff --git a/buffer.go b/buffer.go
index a653243..f895e87 100644
--- a/buffer.go
+++ b/buffer.go
@@ -42,6 +42,11 @@
return len(b.buf) > 0
}
+// len returns how many bytes in the read buffer.
+func (b *buffer) len() int {
+ return len(b.buf)
+}
+
// fill reads into the read buffer until at least _need_ bytes are in it.
func (b *buffer) fill(need int, r readerFunc) error {
// we'll move the contents of the current buffer to dest before filling it.
@@ -86,17 +91,10 @@
// returns next N bytes from buffer.
// The returned slice is only guaranteed to be valid until the next read
-func (b *buffer) readNext(need int, r readerFunc) ([]byte, error) {
- if len(b.buf) < need {
- // refill
- if err := b.fill(need, r); err != nil {
- return nil, err
- }
- }
-
- data := b.buf[:need]
+func (b *buffer) readNext(need int) []byte {
+ data := b.buf[:need:need]
b.buf = b.buf[need:]
- return data, nil
+ return data
}
// takeBuffer returns a buffer with the requested size.
diff --git a/compress.go b/compress.go
index fa42772..e247a65 100644
--- a/compress.go
+++ b/compress.go
@@ -84,9 +84,9 @@
c.buff.Reset()
}
-func (c *compIO) readNext(need int, r readerFunc) ([]byte, error) {
+func (c *compIO) readNext(need int) ([]byte, error) {
for c.buff.Len() < need {
- if err := c.readCompressedPacket(r); err != nil {
+ if err := c.readCompressedPacket(); err != nil {
return nil, err
}
}
@@ -94,8 +94,8 @@
return data[:need:need], nil // prevent caller writes into c.buff
}
-func (c *compIO) readCompressedPacket(r readerFunc) error {
- header, err := c.mc.buf.readNext(7, r) // size of compressed header
+func (c *compIO) readCompressedPacket() error {
+ header, err := c.mc.readNext(7)
if err != nil {
return err
}
@@ -103,7 +103,7 @@
// compressed header structure
comprLength := getUint24(header[0:3])
- compressionSequence := uint8(header[3])
+ compressionSequence := header[3]
uncompressedLength := getUint24(header[4:7])
if debug {
fmt.Printf("uncompress cmplen=%v uncomplen=%v pkt_cmp_seq=%v expected_cmp_seq=%v\n",
@@ -120,7 +120,7 @@
c.mc.sequence = compressionSequence + 1
c.mc.compressSequence = c.mc.sequence
- comprData, err := c.mc.buf.readNext(comprLength, r)
+ comprData, err := c.mc.readNext(comprLength)
if err != nil {
return err
}
diff --git a/packets.go b/packets.go
index a497a50..e6e1704 100644
--- a/packets.go
+++ b/packets.go
@@ -25,19 +25,30 @@
// https://dev.mysql.com/doc/dev/mysql-server/latest/PAGE_PROTOCOL.html
// https://mariadb.com/kb/en/clientserver-protocol/
+// read n bytes from mc.buf
+func (mc *mysqlConn) readNext(n int) ([]byte, error) {
+ if mc.buf.len() < n {
+ err := mc.buf.fill(n, mc.readWithTimeout)
+ if err != nil {
+ return nil, err
+ }
+ }
+ return mc.buf.readNext(n), nil
+}
+
// Read packet to buffer 'data'
func (mc *mysqlConn) readPacket() ([]byte, error) {
var prevData []byte
invalidSequence := false
- readNext := mc.buf.readNext
+ readNext := mc.readNext
if mc.compress {
readNext = mc.compIO.readNext
}
for {
// read packet header
- data, err := readNext(4, mc.readWithTimeout)
+ data, err := readNext(4)
if err != nil {
mc.close()
if cerr := mc.canceled.Value(); cerr != nil {
@@ -85,7 +96,7 @@
}
// read packet body [pktLen bytes]
- data, err = readNext(pktLen, mc.readWithTimeout)
+ data, err = readNext(pktLen)
if err != nil {
mc.close()
if cerr := mc.canceled.Value(); cerr != nil {