0%

TCP粘包拆包问题

TCP粘包拆包问题极其解决方案

TCP协议的特性

TCP协议是面向字节流的协议。TCP中的“流”(stream)指的是流入到进程或从进程流出的字节序列。
面向字节流的含义是:虽然应用程序和TCP的交互是一次一个数据块(大小不等),但是,TCP把应用程序交付下来的数据仅仅看成是一串无结构的字节流,TCP并不知道所传送的字节流的含义。对于应用程序来说,它看到的数据之间没有边界,也无法得知一条报文从哪里开始,到哪里结束,每条报文有多少字节。
而UDP是面向消息的协议,每个UDP段都是一条消息,应用程序必须以消息为单位提取数据,不能一次提取任意字节的数据。因此,UDP通信不会发生粘包问题。

导致粘包的情况

连续发送较短数据

在发送数据时,TCP会根据nagle算法,将数据量小的,且时间间隔较短的数据一次性发给对方。也就是说,如果发送端连续发送了好几个数据包,经过nagle算法的优化,这些小的数据包就可能被封装成一个大的数据包,一次性发送给接收端,而TCP是面向字节流的通信,没有消息保护边界,所以就产生了粘包问题。

接收端没有及时接收数据

还有一种情况会产生粘包,就是接收方没有及时接收数据。可能发送端发来了一段数据,但接收端只接收了部分数据,剩下的小部分数据还遗留在接收缓冲区。那么在下一次接收时,接收缓冲区上不但有上一次遗留的数据,还可能有来自其它报文数据,它们作为一个整体被接收端接收了,这就也造成了粘包。
综上所述,在接收缓冲区上的粘包表现形式主要有以下三种:

  1. packet1和packet2倍合并在一起,一起发送
  2. packet1发生了拆包,packet2与packet1的部分数据被合并起来一起发送
  3. packet2的部分数据没有被及时接收,留在缓冲区与packet1合并起来一起发送

导致拆包的情况

如果发送端缓冲区的长度大于网卡的MTU时,TCP会将这次发送的数据拆成几个数据包发送出去。也就是说,发送端可能只发送了一次数据,接收端却要分好几次才能收到完整的数据

自定义协议

虽然我们无法决定TCP会如何处理发送端发出来的数据,但我们可以借助序列化和反序列化的思想,人为地为数据添加边界。比如,在发送端给待发送的数据加上自定义协议作为报文头,在接收端接收数据时,再把数据还原成我们想要的样子。

举个例子,报文头可以按如下方式构造:

head 协议头 + cmd 控制码 + len 报文数据长度 + crc 校验码 + data 报文数据

  • 协议头(head)是我们在接收缓冲区中识别本程序所需报文的基本依据;
  • 控制码(cmd)用来标识程序中不同报文的作用;
  • 报文数据长度(len)是一个数据报中真实数据的长度,当然也可以是一个数据报的完整长度;
  • 校验码(crc)一般在head、cmd、len的基础上生成,为了进一步确保之前通过协议头判断的数据报是我们要的,(单凭协议头判断目标报文是否存在是不够严谨的,报数据部分也有出现协议头序列的可能)
  • 报文数据(data)就是发送端真正需要发送的数据,也是接收端经过反序列化后,需要得到的数据。

解析报文

接收端调用一次readAll(),会把当前接收缓冲区上的所有数据读取出来,这时候的接收缓冲区上可能有如下几种情况:

不包含协议头

对于这种情况,还需要进行进一步判断:

  • 是目标报文的数据部分。说明数据发送的过程中发生了拆包,需要多次接收数据,直到所接收的数据总长度达到协议中指定的报文数据长度。
  • 不是目标报文的数据部分。可以直接丢弃。

包含一个或多个完整的协议头

对于这种情况,还需要进行进一步判断:

  • 能通过CRC校验。说明当前缓冲区发生了粘包,需要进行循环处理。
  • 不能通过CRC校验。也分两种情况:
  1. 如果读取到的是长度完整的协议,但仍不通过CRC校验,说明当前缓冲区的数据不是目标报文的数据,可以直接丢弃。
    
  2. 如果读取到的是长度不完整的协议(比如协议在拆包时被截断),才导致没有通过CRC校验,就不能判断接下来读取的数据不是目标报文的数据。【对于这种情况,要特殊处理】
    

包含不完整的协议头

对于这种情况,首先肯定是无法通过是否包含协议头的判断的,但如果直接丢弃这段数据,就会造成丢包。所以,要防止这种情况的发生,就要防止在读取数据时,读取过短的数据。
知道缓冲区可能有以上这几种情况,有助于我们对症下药,下面来梳理接收端处理数据的流程:

解析报文的流程:

代码示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
   package main

import (
"bytes"
"encoding/binary"
"fmt"
"io"
"log"
)

// 常量定义
const (
PRIVATE_HEAD uint32 = 0x55AA55AA // 私有头部标识
PROTOCOL_LENGTH uint32 = 16 // 协议长度
)

// BufferData 结构体,用于存储接收到的数据
type BufferData struct {
recevData []byte
hasHead bool
totalLen uint32
}

// CProtocalData 结构体,用于存储解析后的协议数据
type CProtocalData struct {
Header uint32
Cmd uint32
Len uint32
Crc uint32
Data []byte
}

// CDataRecver 结构体,用于接收和处理数据
type CDataRecver struct {
buff BufferData
recvDataVector []CProtocalData
dataArrivedChan chan []byte
}

// NewCDataRecver 创建一个新的 CDataRecver 实例
func NewCDataRecver() *CDataRecver {
return &CDataRecver{
dataArrivedChan: make(chan []byte, 100), // 使用缓冲通道
}
}

// SlotDataArrived 处理到达的数据
func (c *CDataRecver) SlotDataArrived(array []byte) {
c.buff.recevData = append(c.buff.recevData, array...)
c.checkBufferHasHead(&c.buff)

size := uint32(len(c.buff.recevData))
if size >= c.buff.totalLen {
c.parseBufferData(&c.buff, &c.recvDataVector)
}

log.Printf("current recv: %d, total size: %d", len(array), len(c.buff.recevData))
}

// checkBufferHasHead 检查缓冲区是否包含报文头
func (c *CDataRecver) checkBufferHasHead(bufferData *BufferData) {
if bufferData.hasHead {
return
}

index := bytes.Index(bufferData.recevData, uint32ToBytes(PRIVATE_HEAD))
if index == -1 {
bufferData.recevData = bufferData.recevData[:0]
return
}

if index > 0 {
bufferData.recevData = bufferData.recevData[index:]
}

if len(bufferData.recevData) < int(PROTOCOL_LENGTH) {
return
}

header, cmd, length, crc, _ := parseProtocolData(bufferData.recevData)

if !checkCRC(header, cmd, length, crc) {
log.Println("wrong crc")
bufferData.recevData = bufferData.recevData[:0]
bufferData.totalLen = 0
bufferData.hasHead = false
return
}

bufferData.hasHead = true
bufferData.totalLen = length
}

// checkCRC 进行CRC校验
func checkCRC(header, cmd, length, crc uint32) bool {
rightCRC := header + cmd + length - PROTOCOL_LENGTH
return rightCRC == crc
}

// parseBufferData 解析缓冲区的数据
func (c *CDataRecver) parseBufferData(bufferData *BufferData, vector *[]CProtocalData) {
for {
index := bytes.Index(bufferData.recevData, uint32ToBytes(PRIVATE_HEAD))
if index == -1 || len(bufferData.recevData) == 0 {
return
}

if index > 0 {
bufferData.recevData = bufferData.recevData[index:]
}

if len(bufferData.recevData) < int(PROTOCOL_LENGTH) {
break
}

header, cmd, length, crc, data := parseProtocolData(bufferData.recevData)

if !checkCRC(header, cmd, length, crc) {
log.Println("wrong crc")
bufferData.recevData = bufferData.recevData[:0]
bufferData.hasHead = false
bufferData.totalLen = 0
break
}

dataSize := uint32(len(data)) + PROTOCOL_LENGTH
if length > dataSize {
break
}

*vector = append(*vector, CProtocalData{
Header: header,
Cmd: cmd,
Len: length,
Crc: crc,
Data: data,
})

bufferData.recevData = bufferData.recevData[dataSize:]
bufferData.hasHead = false

if len(bufferData.recevData) < 4 {
break
}
}
}

// parseProtocolData 解析协议数据
func parseProtocolData(data []byte) (header, cmd, length, crc uint32, protocolData []byte) {
reader := bytes.NewReader(data)
binary.Read(reader, binary.BigEndian, &header)
binary.Read(reader, binary.BigEndian, &cmd)
binary.Read(reader, binary.BigEndian, &length)
binary.Read(reader, binary.BigEndian, &crc)
protocolData = make([]byte, length-PROTOCOL_LENGTH)
io.ReadFull(reader, protocolData)
return
}

// uint32ToBytes 将uint32转换为字节切片
func uint32ToBytes(n uint32) []byte {
b := make([]byte, 4)
binary.BigEndian.PutUint32(b, n)
return b
}

func main() {
receiver := NewCDataRecver()

// 模拟数据到达
go func() {
testData := []byte{0x55, 0xAA, 0x55, 0xAA, 0, 0, 0, 1, 0, 0, 0, 20, 0, 0, 0, 60, 1, 2, 3, 4}
receiver.dataArrivedChan <- testData
}()

// 处理到达的数据
for data := range receiver.dataArrivedChan {
receiver.SlotDataArrived(data)
}

fmt.Println("Received data:", receiver.recvDataVector)
}