Validation on Google’s TPU-v3

Scenario 1: v3-8 nodes

Hardware Setup

  1. 8 TPUs connected in a bidirectional ring

  2. Each of the 2 links runs at 80 GB/s

  3. JAX All-Reduce is based on the NCCL ring algorithm
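To make the ring algorithm concrete, here is a minimal pure-Python simulation of the classic two-phase ring All-Reduce (reduce-scatter, then all-gather) with a sum reduction. It is a sketch of the general technique, not JAX's or NCCL's actual code. The function name `ring_all_reduce` is hypothetical. The sketch models a single unidirectional ring; a bidirectional ring would typically split the buffer in half and run one such ring in each direction.

```python
from typing import List


def ring_all_reduce(buffers: List[List[float]]) -> List[List[float]]:
    """Simulate a unidirectional ring All-Reduce (sum) on p nodes.

    Each buffer is split into p equal chunks. Phase 1 (reduce-scatter) runs
    p-1 steps; afterwards node i holds the fully reduced chunk (i + 1) % p.
    Phase 2 (all-gather) runs another p-1 steps to circulate those chunks.
    """
    p = len(buffers)
    n = len(buffers[0])
    assert n % p == 0, "buffer length must be divisible by the node count"
    c = n // p                          # chunk size in elements
    data = [list(b) for b in buffers]   # copy so the inputs are not mutated

    def chunk(k: int) -> range:
        return range(k * c, (k + 1) * c)

    # Reduce-scatter: at step s, node i sends chunk (i - s) % p to node i+1,
    # which adds it into its own copy of that chunk.
    for s in range(p - 1):
        # Snapshot all outgoing chunks first: every node sends simultaneously.
        sent = [[data[i][j] for j in chunk((i - s) % p)] for i in range(p)]
        for i in range(p):
            dst, k = (i + 1) % p, (i - s) % p
            for off, j in enumerate(chunk(k)):
                data[dst][j] += sent[i][off]

    # All-gather: at step s, node i forwards the complete chunk (i + 1 - s) % p
    # to node i+1, which overwrites its stale copy.
    for s in range(p - 1):
        sent = [[data[i][j] for j in chunk((i + 1 - s) % p)] for i in range(p)]
        for i in range(p):
            dst, k = (i + 1) % p, (i + 1 - s) % p
            for off, j in enumerate(chunk(k)):
                data[dst][j] = sent[i][off]

    return data


if __name__ == "__main__":
    p = 8  # matches the 8-TPU ring of the v3-8 scenario
    bufs = [[float(i * 100 + j) for j in range(16)] for i in range(p)]
    out = ring_all_reduce(bufs)
    print(out[0])
```

Each node sends 2(p − 1) chunks of n/p elements in total, so the per-node traffic is 2(p − 1)/p · n elements, nearly independent of p. This is why the ring is bandwidth-efficient but pays the per-hop latency 2(p − 1) times.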

ASTRA-Sim setup

  1. Modelled with the ASTRA-Sim Analytical Backend

  2. Bidirectional ring topology

  3. Link latency is 1 µs (as per the TPU research team)
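For a rough sense of the numbers these parameters imply, here is a back-of-the-envelope alpha-beta estimate of ring All-Reduce and All-Gather time on the v3-8 ring. It is not ASTRA-Sim's analytical backend, which models the collective step by step. Several assumptions are not stated in the source and are flagged as such: 1 GB = 10^9 bytes, every step pays the 1 µs link latency once, the bidirectional ring splits the data evenly across both directions, and for All-Gather `size_bytes` is the full gathered buffer. All names are hypothetical.

```python
# Hypothetical alpha-beta estimate for the v3-8 scenario; not ASTRA-Sim code.
LINK_BW = 80e9   # bytes/s per link (assumes 1 GB = 1e9 bytes)
LATENCY = 1e-6   # seconds of latency paid once per ring step (1 us)
P = 8            # TPUs in the ring


def _step_time(size_bytes: float, p: int, directions: int) -> float:
    # Each step moves one chunk of size/p bytes, split evenly across the
    # ring directions (one link per direction), plus the link latency.
    return LATENCY + size_bytes / (p * directions) / LINK_BW


def ring_all_reduce_time(size_bytes: float, p: int = P, directions: int = 2) -> float:
    """Reduce-scatter + all-gather: 2*(p-1) steps, in seconds."""
    return 2 * (p - 1) * _step_time(size_bytes, p, directions)


def ring_all_gather_time(size_bytes: float, p: int = P, directions: int = 2) -> float:
    """All-gather only: p-1 steps; size_bytes is the full gathered buffer."""
    return (p - 1) * _step_time(size_bytes, p, directions)


if __name__ == "__main__":
    for mb in (1, 16, 256):
        size = mb * 1e6
        print(f"{mb:>4} MB  all-reduce {ring_all_reduce_time(size) * 1e6:9.1f} us"
              f"  all-gather {ring_all_gather_time(size) * 1e6:9.1f} us")
```

Under these assumptions a 16 MB All-Reduce costs 14 × (1 + 12.5) µs = 189 µs and a 16 MB All-Gather half of that. For small messages the 14 latency-bound steps dominate; for large ones the bandwidth term does.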

Collectives run

  1. All-Reduce and All-Gather

  2. Reduction operation: sum

Results

All-Reduce results:

All-Gather results:

Scenario 2: v3-32 nodes

Hardware Setup

  1. 32 TPUs connected in a mesh

  2. Each of the 2 (or 3) links runs at 80 GB/s

  3. JAX All-Reduce is based on the NCCL ring algorithm

ASTRA-Sim setup

  1. Modelled with the ASTRA-Sim Analytical Backend

  2. Bidirectional ring and 2D torus (modelled separately)

  3. Link latency is 1 µs (as per the TPU research team)
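To see why the two ASTRA-Sim models of the v3-32 slice can give different results, here is a rough alpha-beta sketch comparing a 32-node bidirectional ring with a hierarchical All-Reduce on a 2D torus. The hierarchical scheme reduce-scatters along each dimension in turn, then all-gathers in reverse order. Several assumptions are not from the source: the 32 TPUs form a 4 × 8 torus, each dimension is a bidirectional ring with 80 GB/s links and 1 µs latency, and phases on different dimensions do not overlap. ASTRA-Sim's own collective scheduling may differ, and the function names are hypothetical.

```python
from typing import Sequence

LINK_BW = 80e9   # bytes/s per link (assumes 1 GB = 1e9 bytes)
LATENCY = 1e-6   # seconds per ring step (1 us)


def ring_phase_time(size_bytes: float, p: int, directions: int = 2) -> float:
    """One reduce-scatter or all-gather phase on a p-node ring: p-1 steps."""
    return (p - 1) * (LATENCY + size_bytes / (p * directions) / LINK_BW)


def ring_all_reduce_time(size_bytes: float, p: int) -> float:
    """Flat ring All-Reduce: reduce-scatter followed by all-gather."""
    return 2 * ring_phase_time(size_bytes, p)


def torus_all_reduce_time(size_bytes: float, dims: Sequence[int]) -> float:
    """Hierarchical All-Reduce on a torus with the given dimension sizes."""
    t = 0.0
    size = size_bytes
    for d in dims:             # reduce-scatter; data shrinks by d each time
        t += ring_phase_time(size, d)
        size /= d
    for d in reversed(dims):   # all-gather; data grows back in reverse order
        size *= d
        t += ring_phase_time(size, d)
    return t


if __name__ == "__main__":
    size = 32e6  # 32 MB
    print(f"ring  (32):   {ring_all_reduce_time(size, 32) * 1e6:.1f} us")
    print(f"torus (4x8):  {torus_all_reduce_time(size, (4, 8)) * 1e6:.1f} us")
```

For a 32 MB All-Reduce this model gives 449.5 µs for the ring and 407.5 µs for the 4 × 8 torus. The bandwidth term is 387.5 µs in both cases, so the torus saves only on latency (20 instead of 62 steps), because each phase uses one dimension's links at a time.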

Collectives run

  1. All-Reduce

  2. Reduction operation: sum

Results

TPU v3-32 modelled as a ring:

TPU v3-32 modelled as a 2D torus: