WaitGroup 用来同步程序中的多个协程,等待集合中的多个协程完成。

0x01 WaitGroup的使用

WaitGroup的使用非常简单,示例如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
import "sync"

func main(){
	var wg = sync.WaitGroup{}

	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func(){
			wg.Done()
		}()
	}
	wg.Wait()
}

使用方法很简单:

  1. 声明一个 WaitGroup 变量 wg
  2. 在协程调用前,调用 wg.Add(1) ,在协程退出前,调用 wg.Done()
  3. 在需要等待协程结束的地方,调用 wg.Wait()

WaitGroup的底层结构

Go的1.4版本中,支持了 WaitGroup,当时的结构体如下:

1
2
3
4
5
6
type WaitGroup struct {
	m       Mutex
	counter int32
	waiters int32
	sema    *uint32
}

可以看到,WaitGroup 中,主要有这4个字段,即一个互斥锁(m),一个计数器(counter),一个等待计数(waiters)和一个信号量(*uint32)。后续的版本,也是这些变量的变种。 在 Go1.7及以后,使用 noCopy 代替了 Muxte

Go1.19 中,WaitGroup的结构体如下:

1
2
3
4
5
type WaitGroup struct{
	noCopy noCopy
	state1 uint64
	state2 uint32
}

在最新版本 Go1.22的结构体如下

1
2
3
4
5
6
type WaitGroup struct {
	noCopy noCopy

	state atomic.Uint64 // high 32 bits are counter, low 32 bits are waiter count.
	sema  uint32
}

最新版本中的 counterwaiters 由一个变量 state 来代替,主要是为了减少加锁造成的开销。正如注释中所讲的,高32位表示 counter,低32位表示 waiter 计数。

在 Go1.19版本中,内存模型如下 Golang_WaitGroup_struct 对应的处理方法如下:

1
2
3
4
5
6
7
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {  
   if unsafe.Alignof(wg.state1) == 8 || uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {  
      return &wg.state1, &wg.state2  
   } else {  
      return (*uint64)(unsafe.Pointer(&state[1])), &state[0]  
   }  
}

即,如果内存是按 64 位对齐,则直接返回接收器 wg 中的 state1state2 字段。 如果内存是按 32 位对齐,则将 state1 转化为一个包含三个元素的数组。后两个元素代表 waiter和counter,第一个元素是 sema。 通过以上可以知道:waiter和counter始终在一块,且保持先后顺序。

在最新版本中,直接舍弃 state() 方法。直接取 state 的高32位和低32位。

0x03 WaitGroup的使用逻辑

相关代码如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
func (wg *WaitGroup) Add(delta int) {
	...
	state := wg.state.Add(uint64(delta) << 32) // 高 32 位加 delta
	v := int32(state >> 32) // v, 也就是 counter 的值,为 state 的高32位
	w := uint32(state) // w,也就是 waiter 的值,为 state 的低32位
	...
	if v > 0 || w == 0 {
		return
	}
	...
	wg.state.Store(0)
	for ; w != 0; w-- {
		runtime_Semrelease(&wg.sema, false, 0)
	}
}

func (wg *WaitGroup) Wait(){
	...
	for {
		state := wg.state.Load()
		v := int32(state >> 32)
		w := uint32(state)
		if v == 0 {
			return
		}
		if wg.state.CompareAndSwap(state, state+1) {
			runtime_Semacquire(&wg.sema)
			if wg.state.Load() != 0 {
				panic("sync: WaitGroup is reused before previous Wait has returned")
			}
			return
		}
	}
}

WaitGroup 中,会用到 sync 的两个函数:runtime_Semrelease()runtime_Semacquire(),这两个函数的详细说明,参见 $GOROOT/src/sync/runtime.go,大致功能为:

  1. Semrelease 原子地增加 s 的值,然后通知因调用 Semacquire 阻塞而等待的协程
  2. Semacquire 等待 s 的值大于0,

使用 WaitGroup 的大致流程如下:

  1. 调用 Add(n),会让 counter 加 n ,
  2. 调用 Done(),内部调用的实际上是 Add(-1) ,让 counter 的值减一。当 count 的值减少为 0 时,减少 waiter 的值。直到为0,然后调用 runtime_Semrelease, 翻译信号量,此时在第 3 步的阻塞会收到信号量,并返回。
  3. 调用 Wait(),让 sema 的值加1,然后调用 runtime_Semacquire(semap),请求信号量,并阻塞。 Golang_waitgroup_core.png