This article is based on DLRover version 0.3.7.
1. DLRover Operator
1.1 Starting the ElasticJob and ScalePlan controllers
Implementation:
```go
// Create the ElasticJob controller
if err = controllers.NewElasticJobReconciler(mgr, masterImage).SetupWithManager(mgr); err != nil {
setupLog.Error(err, "unable to create controller", "controller", "ElasticJob")
os.Exit(1)
}
// Create the ScalePlan controller
if err = controllers.NewScalePlanReconciler(mgr).SetupWithManager(mgr); err != nil {
setupLog.Error(err, "unable to create controller", "controller", "ScalePlan")
os.Exit(1)
}
// Start the controller manager
if err := mgr.Start(ctrl.SetupSignalHandler()); err != nil {
setupLog.Error(err, "problem running manager")
os.Exit(1)
}
```
This part of the code is auto-generated when the Operator skeleton is scaffolded with Kubebuilder.
1.2 The ElasticJob controller
Implementation:
```go
switch job.Status.Phase {
case "", commonv1.JobCreated:
    // Create a Master Pod
r.initializeJob(job)
err := r.createEasydlMaster(job)
if err != nil {
logger.Warningf("Fail to create EasyDL Master")
return ctrl.Result{RequeueAfter: defaultPollInterval}, err
}
r.syncJobStateByReplicas(job)
case commonv1.JobPending:
r.syncJobStateByReplicas(job)
case commonv1.JobRunning:
r.handleFaultPods(job)
r.syncJobStateByReplicas(job)
case commonv1.JobScaling:
scalePlan, err := r.getJobScalePlan(job)
if err != nil {
logger.Errorf("Job %s: Fail to get scaleplan: %s", job.Name, err)
}
if scalePlan.Status.Phase != commonv1.JobPending {
logger.Infof("Job %s: Skip a %s scaleplan %s.", job.Name, scalePlan.Status.Phase, scalePlan.Name)
return ctrl.Result{}, nil
}
r.updateScalePlanScaling(scalePlan)
if scalePlan != nil {
err := r.executeScaling(job, scalePlan)
if err != nil {
logger.Errorf("Job %s: Fail to execute scaleplan %s: %s", job.Name, scalePlan.Name, err)
}
}
r.syncJobStateByReplicas(job)
case commonv1.JobSucceeded:
r.syncJobStateByReplicas(job)
r.stopRunningPods(job)
case commonv1.JobFailed:
logger.Infof("Job %s failed", job.Name)
r.syncJobStateByReplicas(job)
r.stopRunningPods(job)
default:
logger.Warningf("job %s unknown status %s", job.Name, job.Status.Phase)
}
return ctrl.Result{}, nil
```
Despite the many case branches, this is essentially doing two things:
- On initialization, create the DLRover Master Pod
- Synchronize state, propagating the ElasticJob's status to the ScalePlan and the Pods
1.3 The ScalePlan controller
Code:
```go
func (r *ScalePlanReconciler) updateJobToScaling(
scalePlan *elasticv1alpha1.ScalePlan,
job *elasticv1alpha1.ElasticJob,
pollInterval time.Duration) (ctrl.Result, error) {
if scalePlan.Status.Phase != commonv1.JobCreated && scalePlan.Status.Phase != commonv1.JobPending {
logger.Infof("Skip a %s ScalePlan %s", scalePlan.Status.Phase, scalePlan.Name)
return ctrl.Result{}, nil
}
job.Status.ScalePlan = scalePlan.Name
for taskType, resourceSpec := range scalePlan.Spec.ReplicaResourceSpecs {
if job.Status.ReplicaStatuses[taskType].Initial == 0 {
job.Status.ReplicaStatuses[taskType].Initial = int32(resourceSpec.Replicas)
}
}
msg := fmt.Sprintf("ElasticJob %s is scaling by %s with status %s.", job.Name, scalePlan.Name, scalePlan.Status.Phase)
logger.Infof(msg)
if scalePlan.Status.Phase == commonv1.JobCreated {
scalePlan.Status.Phase = commonv1.JobPending
err := updateScalePlanStatus(r.Client, scalePlan)
if err != nil {
return ctrl.Result{RequeueAfter: pollInterval}, err
}
}
common.UpdateStatus(&job.Status, commonv1.JobScaling, common.JobScalingReason, msg)
err := updateElasticJobStatus(r.Client, job)
if err != nil {
logger.Errorf("Failed to update job %s status to scaling with %s, err: %v", job.Name, scalePlan.Name, err)
return ctrl.Result{RequeueAfter: pollInterval}, err
}
return ctrl.Result{}, nil
}
```
The main logic is:
- Associate the ScalePlan with the ElasticJob by recording it in the job's Status
- Update the status of both the ScalePlan and the ElasticJob
1.4 The DLRover Master's entry point
As the code above shows, apart from creating a Master Pod the Operator mostly does state synchronization and bookkeeping; the fault-tolerance logic we are after is not here at all.
Below is the Pod template used to create the DLRover Master Pod:
```go
func NewMasterTemplateToJob(job *elasticv1alpha1.ElasticJob, masterImage string) {
command := masterCommand + fmt.Sprintf(
" --platform pyk8s --namespace %s --job_name %s --port %d",
job.Namespace, job.Name, masterServicePort,
)
container := corev1.Container{
Name: "main",
Image: masterImage,
ImagePullPolicy: defaultImagePullPolicy,
Command: []string{"/bin/bash", "-c", command},
Resources: corev1.ResourceRequirements{
Requests: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse(initMasterContainerCPU),
corev1.ResourceMemory: resource.MustParse(initMasterContainerMemory),
corev1.ResourceEphemeralStorage: resource.MustParse(initMasterContainerStorage),
},
Limits: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse(initMasterContainerCPU),
corev1.ResourceMemory: resource.MustParse(initMasterContainerMemory),
corev1.ResourceEphemeralStorage: resource.MustParse(initMasterContainerStorage),
},
},
}
podTemplate := &corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{container},
RestartPolicy: corev1.RestartPolicyNever,
},
}
if _, ok := job.Spec.ReplicaSpecs[ReplicaTypeJobMaster]; ok {
mainContainer := job.Spec.ReplicaSpecs[ReplicaTypeJobMaster].ReplicaSpec.Template.Spec.Containers[0]
if mainContainer.Image != "" {
podTemplate.Spec.Containers[0].Image = mainContainer.Image
}
if mainContainer.ImagePullPolicy != "" {
podTemplate.Spec.Containers[0].ImagePullPolicy = mainContainer.ImagePullPolicy
}
if len(mainContainer.Env) > 0 {
podTemplate.Spec.Containers[0].Env = append(
podTemplate.Spec.Containers[0].Env, mainContainer.Env...,
)
}
}
podIPEnv := corev1.EnvVar{
Name: envPodIP,
ValueFrom: &corev1.EnvVarSource{
FieldRef: &corev1.ObjectFieldSelector{
APIVersion: "v1",
FieldPath: "status.podIP",
},
},
}
podTemplate.Spec.Containers[0].Env = append(podTemplate.Spec.Containers[0].Env, podIPEnv)
job.Spec.ReplicaSpecs[ReplicaTypeJobMaster] = &elasticv1alpha1.ReplicaSpec{
ReplicaSpec: commonv1.ReplicaSpec{
Template: *podTemplate,
},
}
}
```
This produces a Pod with a startup command similar to the one below, plus a few environment variables.
```yaml
- command:
- /bin/bash
- -c
- python -m dlrover.python.master.main --platform pyk8s --namespace dlrover --job_name torch-mnist-single-job-testing-1 --port 50001
image: registry.cn-hangzhou.aliyuncs.com/intell-ai/dlrover:master
imagePullPolicy: Always
```
Because the restart policy is `Never`, the DLRover Master Pod will not be restarted automatically if it crashes.
1.5 Summary
The DLRover Operator is very lightweight and contains no core processing logic. It mainly:
- Uses CRDs to describe the Job and the ScalePlan scaling requests, converting fields and parameters
- Launches the DLRover Master Pod and lets DLRover take over the job
2. DLRover Master
The entry point of the Master:
```python
def main():
args = parse_master_args()
exit_code = run(args)
return exit_code
```
Key command-line arguments:
- `--port`: default 0, the port the master listens on
- `--node_num`: default 1, the number of nodes
- `--namespace`: default `default`, the namespace in which Pods are created
- `--platform`: default `pyk8s`, the platform type; one of pyk8s, k8s, ray, local
```python
def run(args):
job_args = new_job_args(args.platform, args.job_name, args.namespace)
job_args.initilize()
logger.info("Job args : %s", job_args.to_json(indent=4))
_dlrover_context.config_master_port(port=args.port)
if job_args.platform == PlatformType.LOCAL:
from dlrover.python.master.local_master import LocalJobMaster
worker = job_args.node_args[NodeType.WORKER].group_resource
worker.count = args.node_num
master = LocalJobMaster(_dlrover_context.master_port, job_args)
else:
from dlrover.python.master.dist_master import DistributedJobMaster
update_context(job_args)
master = DistributedJobMaster(_dlrover_context.master_port, job_args)
master.prepare()
return master.run()
```
The key here is the `DistributedJobMaster(JobMaster)` class.
The Master is mainly responsible for:
- Launching nodes (for example, starting Pods on Kubernetes)
- Building the rendezvous of training nodes
- Monitoring node status and starting new nodes to recover from failures
- Collecting training metrics from each node, including throughput and workload
- Auto-scaling the number of nodes of the job to speed up training and improve resource utilization
Related components:
- JobManager: manages the job's nodes; it can launch, monitor, and scale nodes up or down
- RendezvousManager: builds the rendezvous of training nodes
- TaskManager: assigns data-shard tasks to workers and recovers those tasks when a worker fails
- MetricCollector: collects the metrics of the training job
- ElasticPSService: manages the alive parameter-server nodes in parameter-server training jobs
2.1 JobManager
The JobManager manages the job's nodes: it can launch nodes, monitor them, and scale them up or down.
```python
def create_job_manager(args: JobArgs, speed_monitor) -> DistributedJobManager:
critical_worker_index = get_critical_worker_index(args)
# Custom distribution strategy does not exit if there are pending nodes
wait_pending_relaunch = (
args.distribution_strategy == DistributionStrategy.CUSTOM
)
elastic_job = new_elastic_job(args.platform, args.job_name, args.namespace)
node_watcher = new_node_watcher(
args.platform, args.job_name, args.namespace
)
job_scaler = new_job_scaler(args.platform, args.job_name, args.namespace)
node_error_monitor = K8sJobErrorMonitor(
args.namespace, args.cordon_fault_node
)
return DistributedJobManager(
job_args=args,
critical_worker_index=critical_worker_index,
wait_pending_relaunch=wait_pending_relaunch,
speed_monitor=speed_monitor,
job=elastic_job,
node_watcher=node_watcher,
job_scaler=job_scaler,
error_monitor=node_error_monitor,
)
```
The JobManager holds a number of operation handles:
- The `speed_monitor`, which tracks the training speed
```python
def running_speed(self):
if len(self._global_step_records) < 2:
return 0
last_record = self._global_step_records[-1]
first_record = self._global_step_records[-2]
time_diff = last_record.timestamp - first_record.timestamp
if time_diff <= 0:
return 0
speed = (last_record.global_step - first_record.global_step) / (
time_diff
)
return speed
```
`dlrover-run` reports the training speed over gRPC, and `_collect_global_step` is called to update the related metrics.
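For intuition, the same calculation in a tiny standalone form (`GlobalStepRecord` here is just a stand-in for the monitor's internal record type):

```python
from collections import namedtuple

# Stand-in for the record type kept by the speed monitor.
GlobalStepRecord = namedtuple("GlobalStepRecord", ["global_step", "timestamp"])


def running_speed(records):
    """Steps per second, computed from the last two global-step reports."""
    if len(records) < 2:
        return 0
    first, last = records[-2], records[-1]
    time_diff = last.timestamp - first.timestamp
    if time_diff <= 0:
        return 0
    return (last.global_step - first.global_step) / time_diff


records = [GlobalStepRecord(1000, 100.0), GlobalStepRecord(1600, 130.0)]
print(running_speed(records))  # 600 steps in 30 seconds -> 20.0 steps/s
```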
- The `elastic_job`, used to resolve the job's Pod names and Service addresses
```python
class K8sElasticJob(ElasticJob):
def __init__(self, job_name, namespace):
self._k8s_client = k8sClient.singleton_instance(namespace)
self._namespace = namespace
self._job_name = job_name
def get_node_name(self, type, id):
return get_pod_name(self._job_name, type, id)
def get_node_service_addr(self, type, id):
service_name = get_pod_name(self._job_name, type, id)
return "%s.%s.svc:%d" % (
service_name,
self._namespace,
NODE_SERVICE_PORTS[type],
)
```
- The `new_node_watcher`, used to list the job's Pods and watch their events. A DLRover Node corresponds to a Kubernetes Pod (see the conversion sketch after the following code).
```python
class PodWatcher(NodeWatcher):
def watch(self):
resource_version = None
pod_list = self._k8s_client.list_namespaced_pod(self._job_selector)
if pod_list:
resource_version = pod_list.metadata.resource_version
try:
stream = watch.Watch().stream(
self._k8s_client.client.list_namespaced_pod,
self._namespace,
label_selector=self._job_selector,
resource_version=resource_version,
timeout_seconds=60,
)
for event in stream:
node_event = _convert_pod_event_to_node_event(event)
if not node_event:
continue
yield node_event
        except Exception as e:
            ...  # error handling omitted in this excerpt
def list(self) -> List[Node]:
nodes: List[Node] = []
pod_list = self._k8s_client.list_namespaced_pod(self._job_selector)
if not pod_list:
return nodes
if not pod_list.items:
return nodes
...
```
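The `_convert_pod_event_to_node_event` helper is not shown above. Purely as an illustration of the idea (a hedged sketch, not DLRover's actual implementation; the label keys below are assumptions), the conversion might look like:

```python
from dataclasses import dataclass


@dataclass
class SimpleNode:
    """Simplified stand-in for DLRover's Node."""
    type: str
    id: int
    status: str  # e.g. "Pending", "Running", "Succeeded", "Failed"


@dataclass
class SimpleNodeEvent:
    """Simplified stand-in for DLRover's NodeEvent."""
    event_type: str  # "ADDED", "MODIFIED" or "DELETED"
    node: SimpleNode


def convert_pod_event_to_node_event(event):
    """Map a Kubernetes watch event on a Pod to a node event (sketch)."""
    pod = event["object"]
    labels = pod.metadata.labels or {}
    # Hypothetical label keys, used only for this illustration.
    node_type = labels.get("replica-type")
    node_index = labels.get("replica-index")
    if node_type is None or node_index is None:
        return None  # Not a Pod that belongs to this job.
    node = SimpleNode(type=node_type, id=int(node_index), status=pod.status.phase)
    return SimpleNodeEvent(event_type=event["type"], node=node)
```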
- The `new_job_scaler`, used to create and manipulate ScalePlan objects
```python
class ElasticJobScaler(Scaler):
...
def scale(self, plan: ScalePlan):
scale_plan_crd = self._generate_scale_plan_crd(plan)
self._client.create_custom_resource(
group=ElasticJobApi.GROUP,
version=ElasticJobApi.VERION,
plural=ElasticJobApi.SCALEPLAN_PLURAL,
body=scale_plan_crd.to_dict(),
)
self._scaleplan_index += 1
```
- The `node_error_monitor`, which handles errors reported for processes and nodes
```python
def _handle_process_error(
self, node: Node, restart_count: int, error_data: str
):
if restart_count not in self._restart_errors:
self._restart_errors[restart_count] = error_data
logger.error(
f"{node.type}-{node.id} on {node.host_name} "
f"restart {restart_count} fails: {error_data}"
)
return False
def _handle_node_error(self, node: Node, error_data: str):
logger.info(
f"{node.name} on {node.host_name} is down. "
f"Reason: {error_data}"
)
if self.cordon_node_eanbled:
succeed = self._k8s_client.cordon_node(node.host_name)
if succeed:
logger.info(f"Node {node.name} is marked unschedulable.")
return True
```
2.2 RendezvousManager
The RendezvousManager builds the rendezvous (the set) of training nodes.
```python
class ElasticTrainingRendezvousManager(RendezvousManager):
def get_comm_world(
self, node_rank
) -> Tuple[int, int, Dict[int, NodeTopologyMeta]]:
"""如果一个集合点(rendezvous)轮次完成,则返回通信世界(communication world)。
当满足以下任一条件时,集合点完成:
1. 等待节点列表的大小等于最大节点数(max_nodes)。
2. 等待节点列表的大小大于最小节点数(min_nodes),且等于存活节点列表的大小。此外,在等待超时(waiting_timeout)期间,没有更多的工作节点加入集合点。
返回值:
- rdzv_round:轮次索引。
- group:组索引。
- world:类似于 {0: 8, 1: 8, 2: 8} 的字典,其中键是节点ID,值是节点的本地世界大小。
"""
...
```
2.3 TaskManager
The TaskManager assigns data-shard tasks to the workers and recovers those tasks when a worker fails.
```python
class TaskManager(object):
    """Create and dispatch tasks and track their lifecycle."""

    def __init__(self, worker_restart_timeout: int, speed_monitor: SpeedMonitor):
        """Initialize the TaskManager with the worker restart timeout and
        the speed monitor."""

    def new_dataset(
        self,
        batch_size,
        dataset_size,
        dataset_name,
        dataset_splitter: DatasetSplitter,
        task_type=elastic_training_pb2.NONE,
    ):
        """Create a new dataset and initialize its task management."""

    def get_dataset_task(self, node_type, node_id, dataset_name):
        """Get the next task of the given dataset for a node type and node ID."""

    def get_dataset(self, dataset_name):
        """Get a dataset by its name."""
    ...
```
2.4 MetricCollector
The MetricCollector collects the metrics of the training job.
```python
def _create_metric_collector_if_needed(self, params: JobArgs):
if not params.enable_dynamic_sharding:
return None
job_uuid = params.job_uuid
reporter = ReporterType.LOCAL
if params.optimize_mode == OptimizeMode.CLUSTER:
reporter = ReporterType.DLROVER_BRAIN
collector = JobMetricCollector(
job_uuid, params.namespace, params.cluster, params.user, reporter
)
collector.collect_job_type(params.distribution_strategy)
return collector
```
There is a check here: when `params.optimize_mode` is `cluster`, the reported data goes through the BrainReporter and is stored in MySQL; otherwise it goes through the LocalReporter and is kept in the DLRover Master's memory.
```python
class JobMetricCollector(BaseMetricCollector):
def collect_dataset_metric(self, name, size, ds_type=DatasetType.TEXT):
pass
    def collect_training_hyper_params(self, epoch, batch_size):
pass
def collect_job_type(self, job_type):
pass
def collect_model_metric(self, model_info: ModelInfo):
pass
def _report_runtime_stats(self):
pass
def collect_custom_data(self, metric_dict=None):
pass
def collect_runtime_stats(
self, speed_monitor: SpeedMonitor, running_nodes: List[Node]
):
pass
def report_runtime_stats_periodically(self):
pass
def collect_job_exit_reason(self, reason):
pass
```
On the other side, the `MasterServicer` also receives the data that the Agents report over gRPC.
```python
class MasterServicer(elastic_training_pb2_grpc.MasterServicer):
def report(self, request, _):
message = grpc.deserialize_message(request.data)
if isinstance(message, grpc.DatasetShardParams):
success = self._collect_dataset_shard_params(message)
elif isinstance(message, grpc.ResourceStats):
success = self._update_node_resource_usage(
node_type, node_id, message
)
elif isinstance(message, grpc.ModelInfo):
success = self._collect_model_info(message)
elif isinstance(message, grpc.GlobalStep):
success = self._collect_global_step(message)
elif isinstance(message, grpc.ShardCheckpoint):
success = self._restore_shard_checkpoint(message)
elif isinstance(message, grpc.TaskResult):
success = self._report_task_result(message)
elif isinstance(message, grpc.ClusterVersion):
success = self._update_cluster_version(message)
elif isinstance(message, grpc.NodeAddress):
success = self._update_node_address(message)
elif isinstance(message, grpc.NetworkStatus):
success = self._update_node_status(message)
elif isinstance(message, grpc.NodeEvent):
success = self._update_node_event(message)
elif isinstance(message, grpc.SyncJoin):
success = self._join_sync(node_type, node_id, message)
elif isinstance(message, grpc.SyncFinish):
success = self._sync_finished(message)
elif isinstance(message, grpc.SyncBarrier):
...
```
2.5 ElasticPSService
The ElasticPSService manages the parameter-server nodes of a Parameter Server training job.
```python
def _create_elastic_ps_service_if_needed(params: JobArgs):
if params.distribution_strategy == DistributionStrategy.PS:
return ElasticPsService()
return None
```
It only takes effect for Parameter Server jobs, and its job is mainly to track version numbers for the PS job.
```python
class ElasticPsService(object):
    def __init__(self):
        self._global_version = 0
        self._ps_local_version = {}
        self._worker_local_version = {}
        self._worker_restored_version = {}

    def inc_global_cluster_version(self):
        """Increase the global cluster version."""
        pass

    def get_ps_version(self, version_type, ps_id):
        """Get the version of a parameter server (PS).

        Arguments:
            version_type: the version type (global or local)
            ps_id: the ID of the parameter server
        """
        pass

    def update_ps_version(self, ps_id, version_type, version):
        """Update the version of a parameter server (PS).

        Arguments:
            ps_id: the ID of the parameter server
            version_type: the version type (global or local)
            version: the version number to set
        """
        pass

    def get_worker_version(self, version_type, worker_id):
        """Get the version of a worker.

        Arguments:
            version_type: the version type (global, local or restored)
            worker_id: the ID of the worker
        """
        pass

    def update_worker_version(self, worker_id, version_type, version):
        """Update the version of a worker.

        Arguments:
            worker_id: the ID of the worker
            version_type: the version type (global, local or restored)
            version: the version number to set
        """
        pass
```
2.6 prepare: starting the gRPC service and the local managers
```python
def prepare(self):
    # Start the RPC service on the Master so it can communicate with the worker nodes
self._master_server.start()
if self.task_manager:
self.task_manager.start()
if self.job_manager:
self.job_manager.start()
```
Before entering the main loop, the Master starts its RPC service and then the TaskManager and the JobManager.
The TaskManager's `start` launches a thread that checks for timed-out tasks:
```python
def start(self):
if self._worker_restart_timeout > 0:
threading.Thread(
target=self._check_and_reassign_timeout_tasks,
name="check_timeout_tasks",
daemon=True,
).start()
def _check_and_reassign_timeout_tasks(self):
"""Check whether there are timeout tasks periodically."""
logger.info("Start the thread to monitor timeout tasks.")
while True:
for _, dataset in self._datasets.items():
# Copy doing task list because the doing list will pop items
# in the following loop.
doing_tasks = dataset.doing.copy()
cur = time.time()
for task_id, doing_task in doing_tasks.items():
start = self._worker_start_task_time.get(
doing_task.node_id, cur
)
if (
doing_task.task.task_type
== elastic_training_pb2.EVALUATION
and cur - start
> max(
_TASK_TIMEOUT_THRESHOLD_SECS,
self._worker_restart_timeout,
)
):
logger.info(
f"The task {task_id} of {doing_task.node_type}-"
f"{doing_task.node_id} is timeout."
)
dataset.report_task_status(task_id, success=False)
self._invoke_task_timeout_callback(doing_task.node_id)
break
time.sleep(30)
```
The JobManager's `start` initializes the nodes and launches several monitoring threads:
```python
def start(self):
self._scaler.start()
self._job_optimizer.update_job_uuid(self._job_args.job_uuid)
self._job_optimizer.init_job_resource(self._job_resource)
self._adjust_worker_for_estimator()
self._init_nodes()
self._init_job_auto_scaler()
plan = self._create_initial_scale_plan()
if not self._has_running_workers():
        # When the job relaunches an evicted master and there are alive
        # worker nodes, the master does not need to launch workers.
self._scaler.scale(plan)
else:
logger.info(
"The recovered master skips launching workers at begining."
)
worker_num = 0
if NodeType.WORKER in plan.node_group_resources:
worker_num = plan.node_group_resources[NodeType.WORKER].count
if NodeType.CHIEF in plan.node_group_resources:
worker_num += plan.node_group_resources[NodeType.CHIEF].count
self._speed_monitor.set_target_worker_num(worker_num)
threading.Thread(
target=self._monitor_nodes, name="node_monitor", daemon=True
).start()
threading.Thread(
target=self._monitor_node_heart_beat,
name="node_heart_beat_monitor",
daemon=True,
).start()
if os.getenv("KUBERNETES_SERVICE_HOST"):
threading.Thread(
target=self._monitor_scale_plan_crd,
name="scaleplan_monitor",
daemon=True,
).start()
```
```python
def _monitor_nodes(self):
logger.info("Start monitoring nodes events.")
while True:
try:
nodes = self._node_watcher.list()
self._process_list_nodes(nodes)
if self._stop_monitor:
logger.info("Stop processing node events")
break
            # Watch the Pods' status, wrap each change as a NodeEvent, and hand it to _process_event for unified handling
for event in self._node_watcher.watch():
try:
self._process_event(event)
except Exception as e:
logger.warning(e)
detail_trace_back = traceback.format_exc()
logger.warning(detail_trace_back)
except Exception as e:
logger.warning(e)
time.sleep(30)
time.sleep(5)
```
Handling node heartbeat events:
```python
def _monitor_node_heart_beat(self):
logger.info("Start monitoring the heart beat of nodes.")
while True:
with self._lock:
events = self._get_dead_node_event()
        # A node is considered faulty if it has not sent a heartbeat for over 300 seconds
for event in events:
try:
self._process_event(event)
except Exception as e:
logger.warning(e)
detail_trace_back = traceback.format_exc()
logger.warning(detail_trace_back)
time.sleep(15)
```
`_get_dead_node_event` is what collects the events for the abnormal Pods.
```python
def _get_dead_node_event(self, window_interval=300) -> List[NodeEvent]:
now = time.time()
dead_events = []
for _, nodes in self._job_nodes.items():
for _, node in nodes.items():
if (
node.heartbeat_time > 0
and now - node.heartbeat_time > window_interval
and node.status == NodeStatus.RUNNING
):
event_node = copy.deepcopy(node)
event_node.status = NodeStatus.FAILED
event_node.exit_reason = NodeExitReason.NO_HEARTBEAT
event = NodeEvent(
event_type=NodeEventType.DELETED,
node=event_node,
)
dead_events.append(event)
error_data = (
f"No heartbeat for over {window_interval} seconds."
)
self._error_monitor.process_error(
node,
node.relaunch_count,
error_data,
TrainingExceptionLevel.NODE_ERROR,
)
logger.warning(
f"The node {node.name} has not sent a heartbeat "
f"for over {window_interval} seconds."
)
return dead_events
```
The most important part here is the call to `_process_event` (together with `_process_node_events`).
2.7 Fault-handling logic
```python
def _process_event(self, event: NodeEvent):
with self._lock:
should_relaunch = self._should_relaunch(
cur_node, status_change_flow
)
if should_relaunch:
self._relaunch_node(cur_node)
```
Step 1: decide whether the node should be relaunched.
Note that a Node's `exit_reason` here is a value wrapped and converted from the Pod's state, not the Pod's raw status.
```python
def _should_relaunch(self, node: Node, status_change_flow: NodeStateFlow):
should_relaunch = (
status_change_flow.should_relaunch
and self._enable_relaunch_node
and node.relaunchable
)
if should_relaunch:
        # Exclude some special cases: fatal error, OOM, exceeding the max relaunch count, killed
if (
node.exit_reason == NodeExitReason.FATAL_ERROR
and not _dlrover_context.relaunch_always
):
should_relaunch = False
elif node.exit_reason == NodeExitReason.OOM:
mem = node.config_resource.memory
if mem >= NodeResourceLimit.MAX_MEMORY:
should_relaunch = False
logger.warning(
"The memory of worker %s is beyond the limit %s MB.",
mem,
NodeResourceLimit.MAX_MEMORY,
)
elif node.relaunch_count >= node.max_relaunch_count:
should_relaunch = False
logger.warning(
"The relaunched count %s is beyond the maximum %s.",
node.relaunch_count,
node.max_relaunch_count,
)
else:
node.is_recovered_oom = True
self._job_optimizer.adjust_oom_resource(node)
elif node.exit_reason != NodeExitReason.KILLED:
if node.relaunch_count >= node.max_relaunch_count:
logger.warning(
"The relaunch count "
f"{node.relaunch_count}/{node.max_relaunch_count} "
"has been exhausted."
)
should_relaunch = False
return should_relaunch
```
Step 2: relaunch the Node, i.e. the Pod.
Under the AllReduce strategy this means creating a new worker node.
```python
def _relaunch_node(self, node: Node):
if node.type == NodeType.WORKER:
plan = self._worker_manager.relaunch_node(
node, self._remove_exited_node
)
elif node.type == NodeType.PS:
plan = self._ps_manager.relaunch_node(
node, self._remove_exited_node
)
elif node.type == NodeType.EVALUATOR:
plan = self._evaluator_manager.relaunch_node(
node, self._remove_exited_node
)
elif node.type == NodeType.CHIEF or node.type == NodeType.MASTER:
plan = self._chief_manager.relaunch_node(
node, self._remove_exited_node
)
else:
logger.error("Not support node type %s", node.type)
self._set_ps_addrs_in_plan(plan)
if self._remove_exited_node:
plan.remove_nodes.append(node)
node.relaunchable = False # Avoid repeatedly relaunching the node.
self._scaler.scale(plan)
```
When building this ScalePlan, information from the old node such as its `rank_index` and `service_addr` is carried over and used to create the new Pod.
```python
def relaunch_node(self, node: Node, remove_exited_node=False):
plan = ScalePlan()
with self._lock:
new_id = next(self._node_id_iter)
relaunch_node = node.get_relaunch_node_info(new_id)
self._nodes[new_id] = relaunch_node
logger.info("Relaunch node %s to %s", node.name, new_id)
plan.launch_nodes.append(
Node(
node.type,
new_id,
copy.deepcopy(relaunch_node.config_resource),
rank_index=node.rank_index,
name=self._new_node_name_fn(node.type, new_id),
service_addr=node.service_addr,
relaunch_count=relaunch_node.relaunch_count,
)
)
if remove_exited_node and not node.is_released and node.exited():
node.is_released = True
plan.remove_nodes.append(node)
return plan
```
This `ScalePlan` is not submitted to Kubernetes as a CR object; instead it is handed to a `_scaler` that differs per runtime.
```python
class PodScaler(Scaler):
def scale(self, plan: ScalePlan):
with self._lock:
for type, group_resource in plan.node_group_resources.items():
if group_resource.count > len(cur_pods):
self._scale_up_pods(type, plan, cur_pods, max_pod_id)
elif group_resource.count < len(cur_pods):
self._scale_down_pods(type, plan, cur_pods)
for node in plan.launch_nodes:
self._create_node_queue.append(node)
self._update_job_pods(job_pods)
```
On Kubernetes, all worker nodes are created by the PodScaler. `_scale_up_pods` appends the nodes to be created to `_create_node_queue`, which is drained by a background loop (a sketch follows).
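The consumer side of that queue is not shown here. A minimal sketch of such a background creator loop (hypothetical; DLRover's real PodScaler differs in detail):

```python
import threading
import time
from queue import Queue


class PodCreatorSketch:
    """Hypothetical illustration of draining a node-creation queue."""

    def __init__(self, create_pod_fn):
        self._create_node_queue: Queue = Queue()
        self._create_pod_fn = create_pod_fn  # e.g. a Kubernetes client call
        threading.Thread(target=self._periodic_create_pods, daemon=True).start()

    def submit(self, node):
        self._create_node_queue.put(node)

    def _periodic_create_pods(self):
        while True:
            while not self._create_node_queue.empty():
                node = self._create_node_queue.get()
                try:
                    self._create_pod_fn(node)  # create the Pod for this node
                except Exception as exc:
                    print(f"Failed to create a Pod for {node}: {exc}")
            time.sleep(3)


# Usage:
#   scaler = PodCreatorSketch(lambda node: print("create", node))
#   scaler.submit("worker-0")
```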
2.8 Running run to drive the job
```python
def run(self):
"""
The main loop of master.
Dispatch the tasks to the workers until all the tasks are completed.
"""
try:
while True:
if self._stop_requested:
break
msg = self.job_manager.early_stop()
if msg:
self.request_stop(False, msg)
continue
self.job_manager.clear_exited_nodes()
if self.job_manager and self.job_manager.all_workers_exited():
if self.job_manager.pend_without_workers():
time.sleep(30)
continue
if self.job_manager.all_workers_failed():
logger.error("All workers failed")
self._exit_code = 1
self._exit_reason = JobExitReason.UNKNOWN_ERROR
elif (
self.task_manager and not self.task_manager.finished()
):
logger.warning(
"All workers exited but there also are "
"unfinished tasks",
)
break
if (
self.job_manager.all_running_node_hanged()
and self.task_manager.task_hanged()
):
logger.error("All nodes hangeds")
self._exit_code = 1
self._exit_reason = JobExitReason.HANG_ERROR
if (
self.task_manager
and self.task_manager.finished()
and (
not self.job_manager
or self.job_manager.all_critical_node_completed()
)
):
logger.info("All task completed")
break
time.sleep(30)
```
The master runs an endless loop that checks the job state every 30 seconds. It exits when any of the following holds:
- A stop request was received: `self._stop_requested`
- All workers have exited: `self.job_manager.all_workers_exited()`
- The task manager has finished: `self.task_manager.finished()`
2.9 Summary
DLRover's fault-tolerance logic lives mainly in the Master, and the key piece is the JobManager.
The monitored data comes from the following sources:
- Data reported by the Agents, including metrics, training speed, and so on
- Pod event data the Master obtains from Kubernetes
The JobManager wraps this data into NodeEvent objects, which are then handled uniformly by `_process_event`.
3. DLRover Trainer
3.1 The launch script
```
cat /usr/local/bin/dlrover-run
#!/usr/local/bin/python
# -*- coding: utf-8 -*-
import re
import sys
from dlrover.trainer.torch.main import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(main())
```
3.2 The entry function
`dlrover.trainer.torch.main` calls `dlrover.trainer.torch.elastic_run.main`:
```python
@record
def main(args=None):
args = parse_args(args)
run(args)
```
This is just a thin wrapper around the entry point; there is no real logic here.
3.3 Argument parsing
- `--network-check`: check the network before training starts.
- `--node_unit`: the node unit size; the number of scheduled nodes should be a multiple of this value.
- `--auto_config`: automatically configure the number of nodes and the number of processes per node.
- `--auto_tunning`: automatically tune the parallelism configuration.
- `--exclude-straggler`: exclude straggler nodes; only effective when `network-check` is enabled.
- `--save_at_breakpoint`: save a checkpoint to memory when training fails.
- `--accelerator`: the accelerator type of the machine, e.g. `nvidia.com/gpu` or `ascend-npu`.
3.4 Launching the training job
```python
def run(args):
    # Check whether the DLRover Master is reachable
dlrover_master_ready = grpc.addr_connected(master_addr)
_, max_nodes = parse_min_max_nnodes(args.nnodes)
    # If the master is not ready and node_rank == 0, start a DLRover Master on the current node
if not dlrover_master_ready and node_rank == 0:
# Only start the dlrover master on the rank-0 node.
master_handler, master_addr = _launch_dlrover_local_master(
master_addr,
job_name,
max_nodes,
)
logger.info(f"Set the dlrover master addr as {master_addr}")
os.environ[NodeEnv.DLROVER_MASTER_ADDR] = master_addr
use_dlrover_launch = _check_to_use_dlrover_run(master_addr, max_nodes)
if args.standalone and not use_dlrover_launch:
args.rdzv_backend = "c10d"
args.rdzv_endpoint = "localhost:29400"
args.rdzv_id = str(uuid.uuid4())
logger.info(
f"\n**************************************\n"
f"Rendezvous info:\n"
f"--rdzv-backend={args.rdzv_backend} "
f"--rdzv-endpoint={args.rdzv_endpoint} "
f"--rdzv-id={args.rdzv_id}\n"
f"**************************************\n"
)
    # Parse the training configuration
config, cmd, cmd_args = _elastic_config_from_args(args)
config.run_id = job_name
config.role = "dlrover-trainer"
try:
        # Launch the training
elastic_launch(
config=config,
entrypoint=cmd,
use_dlrover_launch=use_dlrover_launch,
)(*cmd_args)
finally:
if master_handler:
master_handler.close()
```
```python
class elastic_launch:
def __call__(self, *args):
if self._use_dlrover_launch:
return launch_agent(self._config, self._entrypoint, list(args))
else:
return torch_launch_agent(
self._config, self._entrypoint, list(args)
)
```
If the job is not managed by DLRover, training is launched directly via `from torch.distributed.launcher.api import launch_agent as torch_launch_agent`. Otherwise the following function launches the training:
```python
def launch_agent(
config: ElasticLaunchConfig,
entrypoint: Union[Callable, str, None],
args: List[Any],
) -> Dict[int, Any]:
    # Generate a unique run_id if one is not provided
if not config.run_id:
run_id = str(uuid.uuid4().int)
logger.warning(
f"config has no run_id, generated a random run_id: {run_id}"
)
config.run_id = run_id
    # Initialize the monitoring
monitor = TorchTrainingMonitor(ConfigPath.RUNTIME_METRICS)
monitor.start()
    # Initialize the agent
...
agent = ElasticTrainingAgent(
node_rank=node_rank,
config=config,
entrypoint=entrypoint,
spec=spec,
start_method=config.start_method,
log_dir=config.log_dir,
)
try:
metrics.initialize_metrics(metrics.MetricsConfig(config.metrics_cfg))
        # Start the agent
result = agent.run()
...
```
```python
class ElasticTrainingAgent(LocalElasticAgent):
...
```
As we can see, DLRover relies on PyTorch's built-in LocalElasticAgent to manage the training processes on each node.
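For reference, this is what driving a training function with PyTorch's stock elastic launcher looks like (standard `torch.distributed` API, shown only to illustrate the machinery DLRover subclasses; it does not involve DLRover itself):

```python
import torch.distributed as dist
from torch.distributed.launcher.api import LaunchConfig, elastic_launch


def train():
    # Each worker process initializes its process group via the env:// variables
    # that the elastic agent sets (MASTER_ADDR, RANK, WORLD_SIZE, ...).
    dist.init_process_group(backend="gloo")
    print(f"rank={dist.get_rank()} world_size={dist.get_world_size()}")
    dist.destroy_process_group()


if __name__ == "__main__":
    config = LaunchConfig(
        min_nodes=1,
        max_nodes=1,
        nproc_per_node=2,
        rdzv_backend="c10d",
        rdzv_endpoint="localhost:29400",
        run_id="local-demo",
    )
    # elastic_launch creates a LocalElasticAgent under the hood to supervise
    # the worker processes; that agent is what DLRover extends.
    elastic_launch(config, train)()
```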
3.5 Metric reporting
As shown above, launching the Agent also starts a monitor for the job.
```python
def launch_agent(
config: ElasticLaunchConfig,
entrypoint: Union[Callable, str, None],
args: List[Any],
) -> Dict[int, Any]:
monitor = TorchTrainingMonitor(ConfigPath.RUNTIME_METRICS)
monitor.start()
```
This monitor periodically reports the node's resource usage.
```python
class TorchTrainingMonitor(Singleton):
def start(self):
if os.getenv(NodeEnv.MONITOR_ENABLED, "false") != "true":
return
self._resource_monitor.start()
thread = threading.Thread(
target=self._periodically_report,
name="report-step",
daemon=True,
)
thread.start()
def _periodically_report(self):
while True:
if self._group_rank == 0:
self.report_resource_with_step()
self.send_heartbeat()
time.sleep(15)
```
There are two reporting flows.
The first collects resource usage inside the Pod with `psutil`.
```python
class ResourceMonitor(Singleton):
def report_resource(self):
try:
used_mem = get_used_memory()
cpu_percent = get_process_cpu_percent()
if self._gpu_enabled:
self._gpu_stats = get_gpu_stats()
current_cpu = round(cpu_percent * self._total_cpu, 2)
self._master_client.report_used_resource(
used_mem, current_cpu, self._gpu_stats
)
logger.debug(
"Report Resource CPU : %s, Memory %s, GPU %s",
current_cpu,
used_mem,
self._gpu_stats,
)
except Exception as e:
logger.exception(e)
```
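The helpers `get_used_memory` and `get_process_cpu_percent` are not shown above; a minimal psutil-based sketch of what they might measure (an assumption for illustration, not DLRover's exact implementation):

```python
import psutil


def get_used_memory_mb():
    """Resident memory of the current process and its children, in MB (sketch)."""
    proc = psutil.Process()
    rss = proc.memory_info().rss
    for child in proc.children(recursive=True):
        try:
            rss += child.memory_info().rss
        except psutil.NoSuchProcess:
            pass
    return int(rss / 1024 / 1024)


def get_process_cpu_percent():
    """Machine-wide CPU utilization as a fraction in [0, 1] (sketch)."""
    return psutil.cpu_percent(interval=1) / 100.0


print(get_used_memory_mb(), get_process_cpu_percent())
```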
The second reads the training speed from the file `/tmp/dlrover/runtime_metrics.json`.
```python
def report_resource_with_step(self):
if self._group_rank != 0:
return
try:
if not os.path.exists(self._metrics_path):
return
with open(self._metrics_path, "r") as f:
record = json.load(f)
step = record.get("step", 0)
timestamp = record.get("timestamp", 0)
if step > 0 and timestamp - self._last_timestamp > 15:
self._resource_monitor.report_resource()
self._last_timestamp = timestamp
self._master_client.report_global_step(
step,
self._last_timestamp,
)
except Exception as e:
logger.warning(e)
```
The ElasticTrainer takes `num_steps` from the recorded gradient state and writes it, together with a `timestamp`, to the agreed-upon metrics file.
```python
class ElasticTrainer(object):
def report_training_step(self):
timestamp = time.time()
record = TrainingRecord(self.gradient_state.num_steps, timestamp)
metric_path = os.getenv(ConfigPath.ENV_RUNTIME_METRICS, "")
rank = get_rank()
if os.path.exists(os.path.dirname(metric_path)) and rank == 0:
with open(metric_path, "w") as f:
f.write(record.to_json(indent=4))
```
3.6 Summary
The Trainer mainly does two things:
- Uses `LocalElasticAgent` to manage the training processes on a node
- Reports data over gRPC, including training speed and resource usage
4. Conclusion
This article examined how DLRover works on Kubernetes, focusing on training jobs under the AllReduce strategy and leaving out PS jobs and the Brain components:
- The DLRover Operator defines the job- and scaling-related fields and maintains their state. It only launches the DLRover Master and implements no elasticity or fault-tolerance logic itself
- Every training job launches its own DLRover Master, which drives the whole training process:
  - JobManager manages job start/stop, fault tolerance, and scaling
  - RendezvousManager handles the rendezvous (networking) of the nodes
  - TaskManager manages the data shards
  - MetricCollector collects the training metrics
  - ElasticPSService manages the parameter-server nodes of PS jobs
- When handling failures, DLRover wraps whatever it detects into NodeEvent objects, which are processed uniformly by `_process_event` in the DLRover Master
- When a job is launched with the `dlrover-run` script, DLRover uses PyTorch's `LocalElasticAgent` to manage the training processes on each node, and also starts a monitor that reports training metrics to the DLRover Master over gRPC