MariaDB Metadata skipping and DEPRECATE_EOF (#1708)
[MariaDB metadata skipping](https://mariadb.com/kb/en/mariadb-protocol-differences-with-mysql/#prepare-statement-skipping-metadata).
With this change, MariaDB server won't send metadata when they have not changed, saving client parsing metadata and network.
This feature rely on these changes:
* extended capabilities support
* EOF packet deprecation makes current implementation to be revised
A benchmark BenchmarkReceiveMetadata has been added to show the difference.
diff --git a/benchmark_test.go b/benchmark_test.go
index 1c3f64d..b246f4a 100644
--- a/benchmark_test.go
+++ b/benchmark_test.go
@@ -129,7 +129,7 @@
b.ReportAllocs()
b.ResetTimer()
- for i := 0; i < concurrencyLevel; i++ {
+ for i := 0; i < concurrencyLevel; i++ {
go func() {
for {
if atomic.AddInt64(&remain, -1) < 0 {
@@ -400,7 +400,7 @@
}
args := make([]any, 200)
- for i := 1; i < 200; i+=2 {
+ for i := 1; i < 200; i += 2 {
args[i] = sval
}
for i := 0; i < 10000; i += 100 {
@@ -455,3 +455,58 @@
func BenchmarkReceive10kRowsCompressed(b *testing.B) {
benchmark10kRows(b, true)
}
+
+// BenchmarkReceiveMetadata measures performance of receiving lots of metadata compare to data in rows
+func BenchmarkReceiveMetadata(b *testing.B) {
+ tb := (*TB)(b)
+
+ // Create a table with 1000 integer fields
+ createTableQuery := "CREATE TABLE large_integer_table ("
+ for i := 0; i < 1000; i++ {
+ createTableQuery += fmt.Sprintf("col_%d INT", i)
+ if i < 999 {
+ createTableQuery += ", "
+ }
+ }
+ createTableQuery += ")"
+
+ // Initialize database
+ db := initDB(b, false,
+ "DROP TABLE IF EXISTS large_integer_table",
+ createTableQuery,
+ "INSERT INTO large_integer_table VALUES ("+
+ strings.Repeat("0,", 999)+"0)", // Insert a row of zeros
+ )
+ defer db.Close()
+
+ b.Run("query", func(b *testing.B) {
+ db.SetMaxIdleConns(0)
+ db.SetMaxIdleConns(1)
+
+ // Create a slice to scan all columns
+ values := make([]any, 1000)
+ valuePtrs := make([]any, 1000)
+ for j := range values {
+ valuePtrs[j] = &values[j]
+ }
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ // Prepare a SELECT query to retrieve metadata
+ stmt := tb.checkStmt(db.Prepare("SELECT * FROM large_integer_table LIMIT 1"))
+ defer stmt.Close()
+
+ // Benchmark metadata retrieval
+ for range b.N {
+ rows := tb.checkRows(stmt.Query())
+
+ rows.Next()
+ // Scan the row
+ err := rows.Scan(valuePtrs...)
+ tb.check(err)
+
+ rows.Close()
+ }
+ })
+}
diff --git a/connection.go b/connection.go
index 3e455a3..58c763f 100644
--- a/connection.go
+++ b/connection.go
@@ -33,7 +33,8 @@
connector *connector
maxAllowedPacket int
maxWriteSize int
- flags clientFlag
+ capabilities capabilityFlag
+ extCapabilities extendedCapabilityFlag
status statusFlag
sequence uint8
compressSequence uint8
@@ -223,13 +224,21 @@
columnCount, err := stmt.readPrepareResultPacket()
if err == nil {
if stmt.paramCount > 0 {
- if err = mc.readUntilEOF(); err != nil {
+ if err = mc.skipColumns(stmt.paramCount); err != nil {
return nil, err
}
}
if columnCount > 0 {
- err = mc.readUntilEOF()
+ if mc.extCapabilities&clientCacheMetadata != 0 {
+ if stmt.columns, err = mc.readColumns(int(columnCount)); err != nil {
+ return nil, err
+ }
+ } else {
+ if err = mc.skipColumns(int(columnCount)); err != nil {
+ return nil, err
+ }
+ }
}
}
@@ -370,19 +379,19 @@
}
// Read Result
- resLen, err := handleOk.readResultSetHeaderPacket()
+ resLen, _, err := handleOk.readResultSetHeaderPacket()
if err != nil {
return err
}
if resLen > 0 {
// columns
- if err := mc.readUntilEOF(); err != nil {
+ if err := mc.skipColumns(resLen); err != nil {
return err
}
// rows
- if err := mc.readUntilEOF(); err != nil {
+ if err := mc.skipRows(); err != nil {
return err
}
}
@@ -419,7 +428,7 @@
// Read Result
var resLen int
- resLen, err = handleOk.readResultSetHeaderPacket()
+ resLen, _, err = handleOk.readResultSetHeaderPacket()
if err != nil {
return nil, err
}
@@ -453,7 +462,7 @@
}
// Read Result
- resLen, err := handleOk.readResultSetHeaderPacket()
+ resLen, _, err := handleOk.readResultSetHeaderPacket()
if err == nil {
rows := new(textRows)
rows.mc = mc
@@ -461,14 +470,14 @@
if resLen > 0 {
// Columns
- if err := mc.readUntilEOF(); err != nil {
+ if err := mc.skipColumns(resLen); err != nil {
return nil, err
}
}
dest := make([]driver.Value, resLen)
if err = rows.readRow(dest); err == nil {
- return dest[0].([]byte), mc.readUntilEOF()
+ return dest[0].([]byte), mc.skipRows()
}
}
return nil, err
diff --git a/connector.go b/connector.go
index bc1d46a..dca473f 100644
--- a/connector.go
+++ b/connector.go
@@ -131,7 +131,7 @@
mc.buf = newBuffer()
// Reading Handshake Initialization Packet
- authData, plugin, err := mc.readHandshakePacket()
+ authData, serverCapabilities, serverExtCapabilities, plugin, err := mc.readHandshakePacket()
if err != nil {
mc.cleanup()
return nil, err
@@ -153,6 +153,7 @@
return nil, err
}
}
+ mc.initCapabilities(serverCapabilities, serverExtCapabilities, mc.cfg)
if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil {
mc.cleanup()
return nil, err
@@ -167,7 +168,8 @@
return nil, err
}
- if mc.cfg.compress && mc.flags&clientCompress == clientCompress {
+ // compression is enabled after auth, not right after sending handshake response.
+ if mc.capabilities&clientCompress > 0 {
mc.compress = true
mc.compIO = newCompIO(mc)
}
diff --git a/const.go b/const.go
index 4aadcd6..311e92e 100644
--- a/const.go
+++ b/const.go
@@ -42,11 +42,12 @@
iERR byte = 0xff
)
-// https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags
-type clientFlag uint32
+// https://dev.mysql.com/doc/dev/mysql-server/latest/group__group__cs__capabilities__flags.html
+// https://mariadb.com/kb/en/connection/#capabilities
+type capabilityFlag uint32
const (
- clientLongPassword clientFlag = 1 << iota
+ clientMySQL capabilityFlag = 1 << iota
clientFoundRows
clientLongFlag
clientConnectWithDB
@@ -73,6 +74,18 @@
clientDeprecateEOF
)
+// https://mariadb.com/kb/en/connection/#capabilities
+type extendedCapabilityFlag uint32
+
+const (
+ progressIndicator extendedCapabilityFlag = 1 << iota
+ clientComMulti
+ clientStmtBulkOperations
+ clientExtendedMetadata
+ clientCacheMetadata
+ clientUnitBulkResult
+)
+
const (
comQuit byte = iota + 1
comInitDB
diff --git a/packets.go b/packets.go
index e6e1704..1319f9e 100644
--- a/packets.go
+++ b/packets.go
@@ -184,20 +184,22 @@
******************************************************************************/
// Handshake Initialization Packet
-// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
-func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) {
+// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html
+// https://mariadb.com/kb/en/connection/#initial-handshake-packet
+func (mc *mysqlConn) readHandshakePacket() (data []byte, capabilities capabilityFlag, extendedCapabilities extendedCapabilityFlag, plugin string, err error) {
data, err = mc.readPacket()
if err != nil {
return
}
if data[0] == iERR {
- return nil, "", mc.handleErrorPacket(data)
+ err = mc.handleErrorPacket(data)
+ return
}
// protocol version [1 byte]
if data[0] < minProtocolVersion {
- return nil, "", fmt.Errorf(
+ return nil, 0, 0, "", fmt.Errorf(
"unsupported protocol version %d. Version %d or higher is required",
data[0],
minProtocolVersion,
@@ -215,15 +217,15 @@
pos += 8 + 1
// capability flags (lower 2 bytes) [2 bytes]
- mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
- if mc.flags&clientProtocol41 == 0 {
- return nil, "", ErrOldProtocol
+ capabilities = capabilityFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
+ if capabilities&clientProtocol41 == 0 {
+ return nil, capabilities, 0, "", ErrOldProtocol
}
- if mc.flags&clientSSL == 0 && mc.cfg.TLS != nil {
+ if capabilities&clientSSL == 0 && mc.cfg.TLS != nil {
if mc.cfg.AllowFallbackToPlaintext {
mc.cfg.TLS = nil
} else {
- return nil, "", ErrNoTLS
+ return nil, capabilities, 0, "", ErrNoTLS
}
}
pos += 2
@@ -233,11 +235,16 @@
// status flags [2 bytes]
pos += 3
// capability flags (upper 2 bytes) [2 bytes]
- mc.flags |= clientFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16
+ capabilities |= capabilityFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16
pos += 2
// length of auth-plugin-data [1 byte]
- // reserved (all [00]) [10 bytes]
- pos += 11
+ // reserved (all [00]) [6 bytes]
+ pos += 7
+ if capabilities&clientMySQL == 0 {
+ // MariaDB server extended flag
+ extendedCapabilities = extendedCapabilityFlag(binary.LittleEndian.Uint32(data[pos : pos+4]))
+ }
+ pos += 4
// second part of the password cipher [minimum 13 bytes],
// where len=MAX(13, length of auth-plugin-data - 8)
@@ -265,82 +272,72 @@
// make a memory safe copy of the cipher slice
var b [20]byte
copy(b[:], authData)
- return b[:], plugin, nil
+ return b[:], capabilities, extendedCapabilities, plugin, nil
}
// make a memory safe copy of the cipher slice
var b [8]byte
copy(b[:], authData)
- return b[:], plugin, nil
+ return b[:], capabilities, 0, plugin, nil
+}
+
+// initCapabilities initializes the capabilities based on server support and configuration
+func (mc *mysqlConn) initCapabilities(serverCapabilities capabilityFlag, serverExtCapabilities extendedCapabilityFlag, cfg *Config) {
+ clientCapabilities :=
+ clientMySQL |
+ clientLongFlag |
+ clientProtocol41 |
+ clientSecureConn |
+ clientTransactions |
+ clientPluginAuthLenEncClientData |
+ clientLocalFiles |
+ clientPluginAuth |
+ clientMultiResults |
+ clientConnectAttrs |
+ clientDeprecateEOF
+
+ if cfg.ClientFoundRows {
+ clientCapabilities |= clientFoundRows
+ }
+ if cfg.compress {
+ clientCapabilities |= clientCompress
+ }
+ // To enable TLS / SSL
+ if mc.cfg.TLS != nil {
+ clientCapabilities |= clientSSL
+ }
+
+ if mc.cfg.MultiStatements {
+ clientCapabilities |= clientMultiStatements
+ }
+ if n := len(cfg.DBName); n > 0 {
+ clientCapabilities |= clientConnectWithDB
+ }
+
+ // only keep client capabilities that server have
+ mc.capabilities = clientCapabilities & serverCapabilities
+
+ // set MariaDB extended clientCacheMetadata capability if server support it
+ mc.extCapabilities = clientCacheMetadata & serverExtCapabilities
}
// Client Authentication Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error {
- // Adjust client flags based on server support
- clientFlags := clientProtocol41 |
- clientSecureConn |
- clientLongPassword |
- clientTransactions |
- clientLocalFiles |
- clientPluginAuth |
- clientMultiResults |
- mc.flags&clientConnectAttrs |
- mc.flags&clientLongFlag
-
- sendConnectAttrs := mc.flags&clientConnectAttrs != 0
-
- if mc.cfg.ClientFoundRows {
- clientFlags |= clientFoundRows
- }
- if mc.cfg.compress && mc.flags&clientCompress == clientCompress {
- clientFlags |= clientCompress
- }
- // To enable TLS / SSL
- if mc.cfg.TLS != nil {
- clientFlags |= clientSSL
- }
-
- if mc.cfg.MultiStatements {
- clientFlags |= clientMultiStatements
- }
-
- // encode length of the auth plugin data
- var authRespLEIBuf [9]byte
- authRespLen := len(authResp)
- authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(authRespLen))
- if len(authRespLEI) > 1 {
- // if the length can not be written in 1 byte, it must be written as a
- // length encoded integer
- clientFlags |= clientPluginAuthLenEncClientData
- }
-
- pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1
-
- // To specify a db name
- if n := len(mc.cfg.DBName); n > 0 {
- clientFlags |= clientConnectWithDB
- pktLen += n + 1
- }
-
- // encode length of the connection attributes
- var connAttrsLEI []byte
- if sendConnectAttrs {
- var connAttrsLEIBuf [9]byte
- connAttrsLen := len(mc.connector.encodedAttributes)
- connAttrsLEI = appendLengthEncodedInteger(connAttrsLEIBuf[:0], uint64(connAttrsLen))
- pktLen += len(connAttrsLEI) + len(mc.connector.encodedAttributes)
- }
-
- // Calculate packet length and get buffer with that size
- data, err := mc.buf.takeBuffer(pktLen + 4)
+ // packet header 4
+ // capabilities 4
+ // maxPacketSize 4
+ // collation id 1
+ // filler 23
+ data, err := mc.buf.takeSmallBuffer(4*3 + 24)
if err != nil {
mc.cleanup()
return err
}
+ _ = data[4*3+23] // boundery check
- // ClientFlags [32 bit]
- binary.LittleEndian.PutUint32(data[4:], uint32(clientFlags))
+ // clientCapabilities [32 bit]
+ binary.LittleEndian.PutUint32(data[4:], uint32(mc.capabilities))
// MaxPacketSize [32 bit] (none)
binary.LittleEndian.PutUint32(data[8:], 0)
@@ -358,16 +355,26 @@
}
// Filler [23 bytes] (all 0x00)
+ // or filler 19bytes + mariadb extCapabilities
pos := 13
- for ; pos < 13+23; pos++ {
- data[pos] = 0
+ if mc.capabilities&clientMySQL == 0 {
+ for ; pos < 13+19; pos++ {
+ data[pos] = 0
+ }
+ // MariaDB Extended Capabilities
+ binary.LittleEndian.PutUint32(data[13+19:], uint32(mc.extCapabilities))
+ } else {
+ for ; pos < 13+23; pos++ {
+ data[pos] = 0
+ }
}
// SSL Connection Request Packet
- // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
+ // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_ssl_request.html
+ // https://mariadb.com/kb/en/connection/#sslrequest-packet
if mc.cfg.TLS != nil {
// Send TLS / SSL request packet
- if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil {
+ if err := mc.writePacket(data); err != nil {
return err
}
@@ -384,34 +391,32 @@
// User [null terminated string]
if len(mc.cfg.User) > 0 {
- pos += copy(data[pos:], mc.cfg.User)
+ data = append(data, mc.cfg.User...)
}
- data[pos] = 0x00
- pos++
+ data = append(data, 0)
// Auth Data [length encoded integer]
- pos += copy(data[pos:], authRespLEI)
- pos += copy(data[pos:], authResp)
+ data = appendLengthEncodedInteger(data, uint64(len(authResp)))
+ data = append(data, authResp...)
- // Databasename [null terminated string]
- if len(mc.cfg.DBName) > 0 {
- pos += copy(data[pos:], mc.cfg.DBName)
- data[pos] = 0x00
- pos++
+ // Database name [null terminated string]
+ if mc.capabilities&clientConnectWithDB != 0 {
+ data = append(data, mc.cfg.DBName...)
+ data = append(data, 0)
}
- pos += copy(data[pos:], plugin)
- data[pos] = 0x00
- pos++
+ data = append(data, plugin...)
+ data = append(data, 0)
// Connection Attributes
- if sendConnectAttrs {
- pos += copy(data[pos:], connAttrsLEI)
- pos += copy(data[pos:], []byte(mc.connector.encodedAttributes))
+ if mc.capabilities&clientConnectAttrs != 0 {
+ connAttrsLen := len(mc.connector.encodedAttributes)
+ data = appendLengthEncodedInteger(data, uint64(connAttrsLen))
+ data = append(data, mc.connector.encodedAttributes...)
}
// Send Auth packet
- return mc.writePacket(data[:pos])
+ return mc.writePacket(data)
}
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
@@ -546,32 +551,37 @@
// Result Set Header Packet
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response.html
-func (mc *okHandler) readResultSetHeaderPacket() (int, error) {
+func (mc *okHandler) readResultSetHeaderPacket() (int, bool, error) {
// handleOkPacket replaces both values; other cases leave the values unchanged.
mc.result.affectedRows = append(mc.result.affectedRows, 0)
mc.result.insertIds = append(mc.result.insertIds, 0)
data, err := mc.conn().readPacket()
if err != nil {
- return 0, err
+ return 0, false, err
}
switch data[0] {
case iOK:
- return 0, mc.handleOkPacket(data)
+ return 0, false, mc.handleOkPacket(data)
case iERR:
- return 0, mc.conn().handleErrorPacket(data)
+ return 0, false, mc.conn().handleErrorPacket(data)
case iLocalInFile:
- return 0, mc.handleInFileRequest(string(data[1:]))
+ return 0, false, mc.handleInFileRequest(string(data[1:]))
}
// column count
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_text_resultset.html
- num, _, _ := readLengthEncodedInteger(data)
+ // https://mariadb.com/kb/en/result-set-packets/#column-count-packet
+ num, _, len := readLengthEncodedInteger(data)
+
+ if mc.extCapabilities&clientCacheMetadata != 0 {
+ return int(num), data[len] == 0x01, nil
+ }
// ignore remaining data in the packet. see #1478.
- return int(num), nil
+ return int(num), true, nil
}
// Error Packet
@@ -695,20 +705,12 @@
func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
columns := make([]mysqlField, count)
- for i := 0; ; i++ {
+ for i := range count {
data, err := mc.readPacket()
if err != nil {
return nil, err
}
- // EOF Packet
- if data[0] == iEOF && (len(data) == 5 || len(data) == 1) {
- if i == count {
- return columns, nil
- }
- return nil, fmt.Errorf("column count mismatch n:%d len:%d", count, len(columns))
- }
-
// Catalog
pos, err := skipLengthEncodedString(data)
if err != nil {
@@ -781,13 +783,13 @@
// Decimals [uint8]
columns[i].decimals = data[pos]
- //pos++
-
- // Default value [len coded binary]
- //if pos < len(data) {
- // defaultVal, _, err = bytesToLengthCodedBinary(data[pos:])
- //}
}
+
+ // skip EOF packet if client does not support deprecateEOF
+ if err := mc.skipEof(); err != nil {
+ return nil, err
+ }
+ return columns, nil
}
// Read Packets as Field Packets until EOF-Packet or an Error appears
@@ -805,9 +807,20 @@
}
// EOF Packet
- if data[0] == iEOF && len(data) == 5 {
- // server_status [2 bytes]
- rows.mc.status = readStatus(data[3:])
+ // text row packets may starts with LengthEncodedString.
+ // In such case, 0xFE can mean string larger than 0xffffff.
+ // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_dt_integers.html#sect_protocol_basic_dt_int_le
+ if data[0] == iEOF && len(data) <= 0xffffff {
+ if mc.capabilities&clientDeprecateEOF == 0 {
+ // Deprecated EOF packet
+ // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_eof_packet.html
+ mc.status = readStatus(data[3:])
+ } else {
+ // Ok Packet with an 0xFE header
+ _, _, n := readLengthEncodedInteger(data[1:]) // affected_rows
+ _, _, m := readLengthEncodedInteger(data[1+n:]) // last_insert_id
+ mc.status = readStatus(data[1+n+m:])
+ }
rows.rs.done = true
if !rows.HasNextResultSet() {
rows.mc = nil
@@ -881,8 +894,34 @@
return nil
}
-// Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read
-func (mc *mysqlConn) readUntilEOF() error {
+func (mc *mysqlConn) skipPackets(n int) error {
+ for i := 0; i < n; i++ {
+ if _, err := mc.readPacket(); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// skips EOF packet after n * ColumnDefinition packets when clientDeprecateEOF is not set
+func (mc *mysqlConn) skipEof() error {
+ if mc.capabilities&clientDeprecateEOF == 0 {
+ if _, err := mc.readPacket(); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (mc *mysqlConn) skipColumns(n int) error {
+ if err := mc.skipPackets(n); err != nil {
+ return err
+ }
+ return mc.skipEof()
+}
+
+// Reads Packets until EOF-Packet or an Error appears.
+func (mc *mysqlConn) skipRows() error {
for {
data, err := mc.readPacket()
if err != nil {
@@ -893,10 +932,20 @@
case iERR:
return mc.handleErrorPacket(data)
case iEOF:
- if len(data) == 5 {
- mc.status = readStatus(data[3:])
+ // text row packets may starts with LengthEncodedString.
+ // In such case, 0xFE can mean string larger than 0xffffff.
+ if len(data) <= 0xffffff {
+ if mc.capabilities&clientDeprecateEOF == 0 {
+ // EOF packet
+ mc.status = readStatus(data[3:])
+ } else {
+ // OK packet with an 0xFE header
+ _, _, n := readLengthEncodedInteger(data[1:]) // affected_rows
+ _, _, m := readLengthEncodedInteger(data[1+n:]) // last_insert_id
+ mc.status = readStatus(data[1+n+m:])
+ }
+ return nil
}
- return nil
}
}
}
@@ -1184,17 +1233,17 @@
// mc.affectedRows and mc.insertIds.
func (mc *okHandler) discardResults() error {
for mc.status&statusMoreResultsExists != 0 {
- resLen, err := mc.readResultSetHeaderPacket()
+ resLen, _, err := mc.readResultSetHeaderPacket()
if err != nil {
return err
}
if resLen > 0 {
// columns
- if err := mc.conn().readUntilEOF(); err != nil {
+ if err := mc.conn().skipColumns(resLen); err != nil {
return err
}
// rows
- if err := mc.conn().readUntilEOF(); err != nil {
+ if err := mc.conn().skipRows(); err != nil {
return err
}
}
@@ -1211,9 +1260,17 @@
// packet indicator [1 byte]
if data[0] != iOK {
- // EOF Packet
- if data[0] == iEOF && len(data) == 5 {
- rows.mc.status = readStatus(data[3:])
+ // EOF/OK Packet
+ if data[0] == iEOF {
+ if rows.mc.capabilities&clientDeprecateEOF == 0 {
+ // EOF packet
+ rows.mc.status = readStatus(data[3:])
+ } else {
+ // OK Packet with an 0xFE header
+ _, _, n := readLengthEncodedInteger(data[1:])
+ _, _, m := readLengthEncodedInteger(data[1+n:])
+ rows.mc.status = readStatus(data[1+n+m:])
+ }
rows.rs.done = true
if !rows.HasNextResultSet() {
rows.mc = nil
diff --git a/packets_test.go b/packets_test.go
index 694b056..b487051 100644
--- a/packets_test.go
+++ b/packets_test.go
@@ -332,11 +332,19 @@
112, 97, 115, 115, 119, 111, 114, 100}
conn.maxReads = 1
- authData, pluginName, err := mc.readHandshakePacket()
+ authData, serverCapabilities, serverExtendedCapabilities, pluginName, err := mc.readHandshakePacket()
if err != nil {
t.Fatalf("got error: %v", err)
}
+ if serverCapabilities != 2148530143 {
+ t.Fatalf("expected serverCapabilities to be 2148530143, got %v", serverCapabilities)
+ }
+
+ if serverExtendedCapabilities != 0 {
+ t.Fatalf("expected serverExtendedCapabilities to be 0, got %v", serverExtendedCapabilities)
+ }
+
if pluginName != "mysql_native_password" {
t.Errorf("expected plugin name 'mysql_native_password', got '%s'", pluginName)
}
diff --git a/rows.go b/rows.go
index df98417..e41fda6 100644
--- a/rows.go
+++ b/rows.go
@@ -113,7 +113,7 @@
// Remove unread packets from stream
if !rows.rs.done {
- err = mc.readUntilEOF()
+ err = mc.skipRows()
}
if err == nil {
handleOk := mc.clearResult()
@@ -143,7 +143,7 @@
// Remove unread packets from stream
if !rows.rs.done {
- if err := rows.mc.readUntilEOF(); err != nil {
+ if err := rows.mc.skipRows(); err != nil {
return 0, err
}
rows.rs.done = true
@@ -156,7 +156,7 @@
rows.rs = resultSet{}
// rows.mc.affectedRows and rows.mc.insertIds accumulate on each call to
// nextResultSet.
- resLen, err := rows.mc.resultUnchanged().readResultSetHeaderPacket()
+ resLen, _, err := rows.mc.resultUnchanged().readResultSetHeaderPacket()
if err != nil {
// Clean up about multi-results flag
rows.rs.done = true
diff --git a/statement.go b/statement.go
index 35df854..0f6c65a 100644
--- a/statement.go
+++ b/statement.go
@@ -20,6 +20,7 @@
mc *mysqlConn
id uint32
paramCount int
+ columns []mysqlField
}
func (stmt *mysqlStmt) Close() error {
@@ -64,19 +65,26 @@
handleOk := stmt.mc.clearResult()
// Read Result
- resLen, err := handleOk.readResultSetHeaderPacket()
+ resLen, metadataFollows, err := handleOk.readResultSetHeaderPacket()
if err != nil {
return nil, err
}
if resLen > 0 {
// Columns
- if err = mc.readUntilEOF(); err != nil {
- return nil, err
+ if metadataFollows && stmt.mc.extCapabilities&clientCacheMetadata != 0 {
+ // we can not skip column metadata because next stmt.Query() may use it.
+ if stmt.columns, err = mc.readColumns(resLen); err != nil {
+ return nil, err
+ }
+ } else {
+ if err = mc.skipColumns(resLen); err != nil {
+ return nil, err
+ }
}
// Rows
- if err := mc.readUntilEOF(); err != nil {
+ if err = mc.skipRows(); err != nil {
return nil, err
}
}
@@ -107,7 +115,7 @@
// Read Result
handleOk := stmt.mc.clearResult()
- resLen, err := handleOk.readResultSetHeaderPacket()
+ resLen, metadataFollows, err := handleOk.readResultSetHeaderPacket()
if err != nil {
return nil, err
}
@@ -116,7 +124,17 @@
if resLen > 0 {
rows.mc = mc
- rows.rs.columns, err = mc.readColumns(resLen)
+ if metadataFollows {
+ if rows.rs.columns, err = mc.readColumns(resLen); err != nil {
+ return nil, err
+ }
+ stmt.columns = rows.rs.columns
+ } else {
+ if err = mc.skipEof(); err != nil {
+ return nil, err
+ }
+ rows.rs.columns = stmt.columns
+ }
} else {
rows.rs.done = true