Update src/parallelism_utils.py
Browse files- src/parallelism_utils.py +6 -12
src/parallelism_utils.py
CHANGED
|
@@ -9,10 +9,10 @@ def get_precision_fac(precision: str):
|
|
| 9 |
raise ValueError("Precision must be either 'mixed' or 'single'")
|
| 10 |
|
| 11 |
|
| 12 |
-
def get_params_fac(model_dtype: torch.dtype):
|
| 13 |
-
if model_dtype == torch.float16:
|
| 14 |
return 2
|
| 15 |
-
elif model_dtype == torch.float32:
|
| 16 |
return 4
|
| 17 |
else:
|
| 18 |
raise ValueError("Model dtype must be either torch.float16 or torch.float32")
|
|
@@ -29,19 +29,13 @@ FP32_PARAM_FACTOR = 4
|
|
| 29 |
MASTER_PARAMS_FACTOR = FP32_PARAM_FACTOR
|
| 30 |
|
| 31 |
|
| 32 |
-
# TODO: check if params_fac is needed during full fp32 training.
|
| 33 |
-
# Normally, mixed precision training results in 1.5x memory compared to FP32.
|
| 34 |
-
# Currently, we are assuming 2x memory for FP32, as deepspeed's ZeRO-2 is optimized for FP16 training.
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
def estimate_zero1_model_states_mem_needs(total_params,
|
| 39 |
num_gpus_per_node=1,
|
| 40 |
num_nodes=1,
|
| 41 |
cpu_offload=True,
|
| 42 |
additional_buffer_factor=1.5,
|
| 43 |
precision="mixed",
|
| 44 |
-
model_dtype = torch.float16,
|
| 45 |
):
|
| 46 |
|
| 47 |
total_gpus = num_nodes * num_gpus_per_node
|
|
@@ -68,7 +62,7 @@ def estimate_zero2_model_states_mem_needs(total_params,
|
|
| 68 |
cpu_offload=True,
|
| 69 |
additional_buffer_factor=1.5,
|
| 70 |
precision="mixed",
|
| 71 |
-
model_dtype = torch.float16,
|
| 72 |
):
|
| 73 |
|
| 74 |
total_gpus = num_nodes * num_gpus_per_node
|
|
@@ -98,7 +92,7 @@ def estimate_zero3_model_states_mem_needs(total_params,
|
|
| 98 |
zero_init=True,
|
| 99 |
additional_buffer_factor=1.5,
|
| 100 |
precision="mixed",
|
| 101 |
-
model_dtype = torch.float16,
|
| 102 |
):
|
| 103 |
|
| 104 |
total_gpus = num_nodes * num_gpus_per_node
|
|
|
|
| 9 |
raise ValueError("Precision must be either 'mixed' or 'single'")
|
| 10 |
|
| 11 |
|
| 12 |
+
def get_params_fac(model_dtype: str) -> int:
    """Return the number of bytes used to store one model parameter.

    Args:
        model_dtype: Name of the parameter dtype, either ``"float16"``
            or ``"float32"``.

    Returns:
        2 for half precision (``"float16"``), 4 for single precision
        (``"float32"``).

    Raises:
        ValueError: If ``model_dtype`` is not a supported dtype string.
    """
    if model_dtype == "float16":
        return 2
    elif model_dtype == "float32":
        return 4
    else:
        # The comparisons above are against plain strings, not torch
        # dtypes, so the error message must name the string values
        # (the old message still referenced torch.float16/torch.float32).
        raise ValueError("Model dtype must be either 'float16' or 'float32'")
|
|
|
|
| 29 |
MASTER_PARAMS_FACTOR = FP32_PARAM_FACTOR
|
| 30 |
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
def estimate_zero1_model_states_mem_needs(total_params,
|
| 33 |
num_gpus_per_node=1,
|
| 34 |
num_nodes=1,
|
| 35 |
cpu_offload=True,
|
| 36 |
additional_buffer_factor=1.5,
|
| 37 |
precision="mixed",
|
| 38 |
+
model_dtype = "float16",
|
| 39 |
):
|
| 40 |
|
| 41 |
total_gpus = num_nodes * num_gpus_per_node
|
|
|
|
| 62 |
cpu_offload=True,
|
| 63 |
additional_buffer_factor=1.5,
|
| 64 |
precision="mixed",
|
| 65 |
+
model_dtype = "float16",
|
| 66 |
):
|
| 67 |
|
| 68 |
total_gpus = num_nodes * num_gpus_per_node
|
|
|
|
| 92 |
zero_init=True,
|
| 93 |
additional_buffer_factor=1.5,
|
| 94 |
precision="mixed",
|
| 95 |
+
model_dtype = "float16",
|
| 96 |
):
|
| 97 |
|
| 98 |
total_gpus = num_nodes * num_gpus_per_node
|