WaitGroup
用来同步程序中的多个协程,等待集合中的多个协程完成。
0x01 WaitGroup的使用
WaitGroup的使用非常简单,示例如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
|
import "sync"
func main(){
var wg = sync.WaitGroup{}
for i := 0; i < 5; i++ {
wg.Add(1)
go func(){
wg.Done()
}()
}
wg.Wait()
}
|
使用方法很简单:
- 声明一个
WaitGroup
变量 wg
- 在协程调用前,调用
wg.Add(1)
,在协程退出前,调用 wg.Done()
- 在需要等待协程结束的地方,调用
wg.Wait()
WaitGroup的底层结构
在Go的1.4版本中,支持了 WaitGroup
,当时的结构体如下:
1
2
3
4
5
6
|
type WaitGroup struct {
m Mutex
counter int32
waiters int32
sema *uint32
}
|
可以看到,WaitGroup
中,主要有这4个字段,即一个互斥锁(m
),一个计数器(counter
),一个等待计数(waiters
)和一个信号量(*uint32
)。后续的版本,也是这些变量的变种。
在 Go1.7及以后,使用 noCopy
代替了 Muxte
Go1.19 中,WaitGroup的结构体如下:
1
2
3
4
5
|
type WaitGroup struct{
noCopy noCopy
state1 uint64
state2 uint32
}
|
在最新版本 Go1.22的结构体如下
1
2
3
4
5
6
|
type WaitGroup struct {
noCopy noCopy
state atomic.Uint64 // high 32 bits are counter, low 32 bits are waiter count.
sema uint32
}
|
最新版本中的 counter
和 waiters
由一个变量 state
来代替,主要是为了减少加锁造成的开销。正如注释中所讲的,高32位表示 counter
,低32位表示 waiter
计数。
在 Go1.19版本中,内存模型如下
对应的处理方法如下:
1
2
3
4
5
6
7
|
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
if unsafe.Alignof(wg.state1) == 8 || uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
return &wg.state1, &wg.state2
} else {
return (*uint64)(unsafe.Pointer(&state[1])), &state[0]
}
}
|
即,如果内存是按 64 位对齐,则直接返回接收器 wg
中的 state1 和 state2 字段。
如果内存是按 32 位对齐,则将 state1 转化为一个包含三个元素的数组。后两个元素代表 waiter和counter,第一个元素是 sema。
通过以上可以知道:waiter和counter始终在一块,且保持先后顺序。
在最新版本中,直接舍弃 state()
方法。直接取 state
的高32位和低32位。
0x03 WaitGroup的使用逻辑
相关代码如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
|
func (wg *WaitGroup) Add(delta int) {
...
state := wg.state.Add(uint64(delta) << 32) // 高 32 位加 delta
v := int32(state >> 32) // v, 也就是 counter 的值,为 state 的高32位
w := uint32(state) // w,也就是 waiter 的值,为 state 的低32位
...
if v > 0 || w == 0 {
return
}
...
wg.state.Store(0)
for ; w != 0; w-- {
runtime_Semrelease(&wg.sema, false, 0)
}
}
func (wg *WaitGroup) Wait(){
...
for {
state := wg.state.Load()
v := int32(state >> 32)
w := uint32(state)
if v == 0 {
return
}
if wg.state.CompareAndSwap(state, state+1) {
runtime_Semacquire(&wg.sema)
if wg.state.Load() != 0 {
panic("sync: WaitGroup is reused before previous Wait has returned")
}
return
}
}
}
|
在 WaitGroup
中,会用到 sync 的两个函数:runtime_Semrelease()
和 runtime_Semacquire()
,这两个函数的详细说明,参见 $GOROOT/src/sync/runtime.go
,大致功能为:
Semrelease
原子地增加 s 的值,然后通知因调用 Semacquire
阻塞而等待的协程
Semacquire
等待 s 的值大于0,
使用 WaitGroup
的大致流程如下:
- 调用
Add(n)
,会让 counter 加 n ,
- 调用
Done()
,内部调用的实际上是 Add(-1)
,让 counter 的值减一。当 count 的值减少为 0 时,减少 waiter
的值。直到为0,然后调用 runtime_Semrelease
, 翻译信号量,此时在第 3 步的阻塞会收到信号量,并返回。
- 调用
Wait()
,让 sema
的值加1,然后调用 runtime_Semacquire(semap)
,请求信号量,并阻塞。