Please enable Javascript to view the contents

Kubernetes 下的 DLRover 工作流程分析

 ·  ☕ 13 分钟

本文使用的 DLRover 版本是 0.3.7

1. DLRover Operator

1.1 启动 ElasticJob 和 ScalePlan 的控制器

实现代码:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
    // 创建 ElasticJob 的控制器
    if err = controllers.NewElasticJobReconciler(mgr, masterImage).SetupWithManager(mgr); err != nil {
        setupLog.Error(err, "unable to create controller", "controller", "ElasticJob")
        os.Exit(1)
    }
    // 创建 ScalePlan 的控制器
    if err = controllers.NewScalePlanReconciler(mgr).SetupWithManager(mgr); err != nil {
        setupLog.Error(err, "unable to create controller", "controller", "ScalePlan")
        os.Exit(1)
    }
    // 启动控制器
    if err := mgr.Start(ctrl.SetupSignalHandler()); err != nil {
        setupLog.Error(err, "problem running manager")
        os.Exit(1)
    }

这部分代码是使用 Kubebuilder 生成 Operator 框架时,自动生成的。

1.2 ElasticJob 控制器

实现代码:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
    // Drive the ElasticJob according to its current phase.
    switch job.Status.Phase {
    case "", commonv1.JobCreated:
        // New job: initialize it and create the master Pod.
        r.initializeJob(job)
        err := r.createEasydlMaster(job)
        if err != nil {
            logger.Warningf("Fail to create EasyDL Master")
            return ctrl.Result{RequeueAfter: defaultPollInterval}, err
        }
        r.syncJobStateByReplicas(job)
    case commonv1.JobPending:
        r.syncJobStateByReplicas(job)
    case commonv1.JobRunning:
        r.handleFaultPods(job)
        r.syncJobStateByReplicas(job)
    case commonv1.JobScaling:
        scalePlan, err := r.getJobScalePlan(job)
        if err != nil {
            logger.Errorf("Job %s: Fail to get scaleplan: %s", job.Name, err)
        }
        // Guard before the first dereference: without a ScalePlan there is
        // nothing to execute, so retry later instead of panicking on a nil
        // pointer.
        if scalePlan == nil {
            return ctrl.Result{RequeueAfter: defaultPollInterval}, err
        }
        // Only a pending ScalePlan is executed; any other phase is skipped.
        if scalePlan.Status.Phase != commonv1.JobPending {
            logger.Infof("Job %s: Skip a %s scaleplan %s.", job.Name, scalePlan.Status.Phase, scalePlan.Name)
            return ctrl.Result{}, nil
        }
        r.updateScalePlanScaling(scalePlan)
        // A scaling failure is logged but does not abort the reconcile; the
        // job state is still synchronized below.
        if err := r.executeScaling(job, scalePlan); err != nil {
            logger.Errorf("Job %s: Fail to execute scaleplan %s: %s", job.Name, scalePlan.Name, err)
        }
        r.syncJobStateByReplicas(job)
    case commonv1.JobSucceeded:
        r.syncJobStateByReplicas(job)
        r.stopRunningPods(job)
    case commonv1.JobFailed:
        logger.Infof("Job %s failed", job.Name)
        r.syncJobStateByReplicas(job)
        r.stopRunningPods(job)
    default:
        logger.Warningf("job %s unknown status %s", job.Name, job.Status.Phase)
    }
    return ctrl.Result{}, nil

虽然有很多的 case 判断,但主要在做两件事:

  1. 初始化时,创建 DLRover Master Pod
  2. 同步状态,将 ElasticJob 的状态同步到 ScalePlan、Pod 上

1.3 ScalePlan 控制器

代码:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
// updateJobToScaling records scalePlan on the job's status and moves the job
// into the JobScaling phase.
//
// A ScalePlan whose phase is neither JobCreated nor JobPending is skipped. A
// JobCreated ScalePlan is advanced to JobPending before the job is updated.
// When either status update fails, the error is returned together with a
// request to requeue after pollInterval.
func (r *ScalePlanReconciler) updateJobToScaling(
    scalePlan *elasticv1alpha1.ScalePlan,
    job *elasticv1alpha1.ElasticJob,
    pollInterval time.Duration) (ctrl.Result, error) {
    // Snapshot the phase: it is reported in the messages below and may be
    // rewritten further down.
    phase := scalePlan.Status.Phase
    switch phase {
    case commonv1.JobCreated, commonv1.JobPending:
        // Only these two phases are processed.
    default:
        logger.Infof("Skip a %s ScalePlan %s", phase, scalePlan.Name)
        return ctrl.Result{}, nil
    }

    job.Status.ScalePlan = scalePlan.Name
    // Fill in the initial replica count only for task types that have none yet.
    for taskType, resourceSpec := range scalePlan.Spec.ReplicaResourceSpecs {
        replicaStatus := job.Status.ReplicaStatuses[taskType]
        if replicaStatus.Initial != 0 {
            continue
        }
        replicaStatus.Initial = int32(resourceSpec.Replicas)
    }

    msg := fmt.Sprintf("ElasticJob %s is scaling by %s with status %s.", job.Name, scalePlan.Name, phase)
    logger.Infof(msg)

    if phase == commonv1.JobCreated {
        scalePlan.Status.Phase = commonv1.JobPending
        if err := updateScalePlanStatus(r.Client, scalePlan); err != nil {
            return ctrl.Result{RequeueAfter: pollInterval}, err
        }
    }

    common.UpdateStatus(&job.Status, commonv1.JobScaling, common.JobScalingReason, msg)
    if err := updateElasticJobStatus(r.Client, job); err != nil {
        logger.Errorf("Failed to update job %s status to scaling with %s, err: %v", job.Name, scalePlan.Name, err)
        return ctrl.Result{RequeueAfter: pollInterval}, err
    }
    return ctrl.Result{}, nil
}

主要的逻辑是:

  1. 将 scalePlan 关联到 ElasticJob 对象的 Status 中
  2. 更新 ScalePlan 和 ElasticJob 的状态

1.4 DLRover Master 的启动入口

从前面的代码可以看到 Operator 除了创建出一个 Master Pod 之外,主要是各种状态同步和关联,根本没有我们想要的容错逻辑。

下面是创建 DLRover Master Pod 的模板配置:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
// NewMasterTemplateToJob builds the Pod template of the DLRover master and
// stores it in job.Spec.ReplicaSpecs[ReplicaTypeJobMaster], replacing any
// entry that was there before.
//
// The template holds a single container named "main" that runs masterCommand
// with the job's namespace, name and masterServicePort. Its resource requests
// and limits are identical, and the Pod restart policy is Never. When the job
// already carries a master replica spec with at least one container, the
// image, image pull policy and environment variables of that first container
// override or extend the defaults. The Pod IP is always exposed through the
// envPodIP environment variable.
//
// NOTE(review): the final assignment assumes job.Spec.ReplicaSpecs is a
// non-nil map — confirm that callers initialize it.
func NewMasterTemplateToJob(job *elasticv1alpha1.ElasticJob, masterImage string) {
    command := masterCommand + fmt.Sprintf(
        " --platform pyk8s --namespace %s --job_name %s --port %d",
        job.Namespace, job.Name, masterServicePort,
    )
    container := corev1.Container{
        Name:            "main",
        Image:           masterImage,
        ImagePullPolicy: defaultImagePullPolicy,
        Command:         []string{"/bin/bash", "-c", command},
        Resources: corev1.ResourceRequirements{
            Requests: corev1.ResourceList{
                corev1.ResourceCPU:              resource.MustParse(initMasterContainerCPU),
                corev1.ResourceMemory:           resource.MustParse(initMasterContainerMemory),
                corev1.ResourceEphemeralStorage: resource.MustParse(initMasterContainerStorage),
            },
            Limits: corev1.ResourceList{
                corev1.ResourceCPU:              resource.MustParse(initMasterContainerCPU),
                corev1.ResourceMemory:           resource.MustParse(initMasterContainerMemory),
                corev1.ResourceEphemeralStorage: resource.MustParse(initMasterContainerStorage),
            },
        },
    }
    podTemplate := &corev1.PodTemplateSpec{
        Spec: corev1.PodSpec{
            Containers:    []corev1.Container{container},
            RestartPolicy: corev1.RestartPolicyNever,
        },
    }
    // Apply user overrides only when the job's master spec really has a
    // container to read from; a nil entry or an empty container list would
    // otherwise panic on the index below.
    if masterSpec, ok := job.Spec.ReplicaSpecs[ReplicaTypeJobMaster]; ok &&
        masterSpec != nil &&
        len(masterSpec.ReplicaSpec.Template.Spec.Containers) > 0 {
        mainContainer := masterSpec.ReplicaSpec.Template.Spec.Containers[0]
        if mainContainer.Image != "" {
            podTemplate.Spec.Containers[0].Image = mainContainer.Image
        }
        if mainContainer.ImagePullPolicy != "" {
            podTemplate.Spec.Containers[0].ImagePullPolicy = mainContainer.ImagePullPolicy
        }
        if len(mainContainer.Env) > 0 {
            podTemplate.Spec.Containers[0].Env = append(
                podTemplate.Spec.Containers[0].Env, mainContainer.Env...,
            )
        }
    }
    // Expose the Pod's own IP to the container through the downward API.
    podIPEnv := corev1.EnvVar{
        Name: envPodIP,
        ValueFrom: &corev1.EnvVarSource{
            FieldRef: &corev1.ObjectFieldSelector{
                APIVersion: "v1",
                FieldPath:  "status.podIP",
            },
        },
    }
    podTemplate.Spec.Containers[0].Env = append(podTemplate.Spec.Containers[0].Env, podIPEnv)
    job.Spec.ReplicaSpecs[ReplicaTypeJobMaster] = &elasticv1alpha1.ReplicaSpec{
        ReplicaSpec: commonv1.ReplicaSpec{
            Template: *podTemplate,
        },
    }
}

创建了类似下面启动命令的一个 Pod,然后设置了一些环境变量。

1
2
3
4
5
6
  - command:
    - /bin/bash
    - -c
    - python -m dlrover.python.master.main --platform pyk8s --namespace dlrover --job_name torch-mnist-single-job-testing-1 --port 50001
    image: registry.cn-beijing.aliyuncs.com/intell-ai/dlrover:master
    imagePullPolicy: Always

重启策略是 Never,也就是说如果 DLRover Master Pod 挂了,不会自动重启。

1.5 小结

DLRover Operator 非常轻量,没有核心的处理逻辑实现,主要是:

  1. 使用 CRD 描述 Job 任务、ScalePlan 扩容任务,实现字段、参数的转换
  2. 启动 DLRover Master Pod,让 DLRover 接管 Job 任务

2. DLRover Master

  • 执行入口
1
2
3
4
def main():
    """Parse the master's command-line arguments, run the master and return its exit code."""
    return run(parse_master_args())
  • 参数解析

--port 默认值 0,监听 master 的端口
--node_num 默认值 1,节点数量
--namespace 默认值 default ,创建 Pod 的命名空间
--platform 默认值 pyk8s,平台类型,可选 pyk8s,k8s,ray,local

  • 启动逻辑
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
def run(args):
    """Build the job master for the requested platform and run it.

    Args:
        args: Parsed master arguments. This function reads ``platform``,
            ``job_name``, ``namespace``, ``port`` and ``node_num``.

    Returns:
        Whatever ``master.run()`` returns.
    """
    job_args = new_job_args(args.platform, args.job_name, args.namespace)
    job_args.initilize()
    logger.info("Job args : %s", job_args.to_json(indent=4))
    _dlrover_context.config_master_port(port=args.port)
    if job_args.platform == PlatformType.LOCAL:
        # Imported lazily, only when the local platform is selected.
        from dlrover.python.master.local_master import LocalJobMaster

        # On the local platform the worker count is taken from args.node_num.
        worker = job_args.node_args[NodeType.WORKER].group_resource
        worker.count = args.node_num
        master = LocalJobMaster(_dlrover_context.master_port, job_args)
    else:
        # Every non-local platform uses the distributed master.
        from dlrover.python.master.dist_master import DistributedJobMaster

        update_context(job_args)
        master = DistributedJobMaster(_dlrover_context.master_port, job_args)
    # prepare() must run before run().
    master.prepare()
    return master.run()

这里的关键就是 class DistributedJobMaster(JobMaster) 类。

Master 主要实现:

  • 启动节点(例如,在 Kubernetes 上启动 Pod)
  • 构建 rendezvous 训练节点集合
  • 监控节点状态,在节点故障时启动新的节点进行恢复
  • 收集每个节点的训练指标,包括吞吐量和工作负载
  • 自动调节任务的节点数量,以加速训练并提高资源利用率

相关组件:

  • JobManager,管理任务的节点。任务管理器可以启动节点、监控节点以及对节点进行扩容或缩容
  • RendezvousManager,构建训练节点的集合
  • TaskManager,分配数据分片任务给工作节点,并在工作节点故障时恢复数据分片任务
  • MetricCollector,收集训练任务的指标
  • ElasticPSService,管理参数服务器训练任务中存活的参数服务器节点

2.1 JobManager

JobManager 管理任务的节点。任务管理器可以启动节点、监控节点以及对节点进行扩容或缩容。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def create_job_manager(args: JobArgs, speed_monitor) -> DistributedJobManager:
    """Assemble a DistributedJobManager and the helpers it is built from.

    Args:
        args: Job arguments; ``platform``, ``job_name``, ``namespace``,
            ``distribution_strategy`` and ``cordon_fault_node`` are read.
        speed_monitor: Passed through to the manager unchanged.

    Returns:
        The newly constructed DistributedJobManager.
    """
    critical_index = get_critical_worker_index(args)
    # Custom distribution strategy does not exit if there are pending nodes
    is_custom = args.distribution_strategy == DistributionStrategy.CUSTOM

    job = new_elastic_job(args.platform, args.job_name, args.namespace)
    watcher = new_node_watcher(args.platform, args.job_name, args.namespace)
    scaler = new_job_scaler(args.platform, args.job_name, args.namespace)
    error_monitor = K8sJobErrorMonitor(args.namespace, args.cordon_fault_node)

    return DistributedJobManager(
        job_args=args,
        critical_worker_index=critical_index,
        wait_pending_relaunch=is_custom,
        speed_monitor=speed_monitor,
        job=job,
        node_watcher=watcher,
        job_scaler=scaler,
        error_monitor=error_monitor,
    )

JobManager 中包含大量的操作句柄:

  • 监控训练速度的 speed_monitor
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
    def running_speed(self):
        """Return the global-step rate between the two most recent records.

        Returns 0 when fewer than two records exist or when the newer
        record's timestamp is not later than the older one's.
        """
        records = self._global_step_records
        if len(records) < 2:
            return 0
        previous, latest = records[-2], records[-1]
        elapsed = latest.timestamp - previous.timestamp
        if elapsed <= 0:
            return 0
        return (latest.global_step - previous.global_step) / elapsed

通过 dlrover-run 使用 gRPC 上报训练速度,调用 _collect_global_step 更新相关指标。

  • 获取 Job 相关 Pod 名字、Service 地址的 elastic_job
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
class K8sElasticJob(ElasticJob):
    """Derives Pod names and service addresses for the nodes of one job."""

    def __init__(self, job_name, namespace):
        """Keep the job identity and obtain the shared client for ``namespace``."""
        self._k8s_client = k8sClient.singleton_instance(namespace)
        self._namespace, self._job_name = namespace, job_name

    def get_node_name(self, type, id):
        """Return the Pod name built from the job name, node type and node id."""
        pod_name = get_pod_name(self._job_name, type, id)
        return pod_name

    def get_node_service_addr(self, type, id):
        """Return ``<pod name>.<namespace>.svc:<port>`` for the node.

        The port is looked up in NODE_SERVICE_PORTS by node type.
        """
        host = get_pod_name(self._job_name, type, id)
        port = NODE_SERVICE_PORTS[type]
        return "%s.%s.svc:%d" % (host, self._namespace, port)
  • 获取 Job 相关 Pod 列表、监听事件的 new_node_watcher

DLRover 的 Node 对应 Kubernetes 的 Pod。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
class PodWatcher(NodeWatcher):
    def watch(self):
        resource_version = None
        pod_list = self._k8s_client.list_namespaced_pod(self._job_selector)
        if pod_list:
            resource_version = pod_list.metadata.resource_version
        try:
            stream = watch.Watch().stream(
                self._k8s_client.client.list_namespaced_pod,
                self._namespace,
                label_selector=self._job_selector,
                resource_version=resource_version,
                timeout_seconds=60,
            )
            for event in stream:
                node_event = _convert_pod_event_to_node_event(event)
                if not node_event:
                    continue
                yield node_event
        except Exception as e:
    def list(self) -> List[Node]:
        nodes: List[Node] = []
        pod_list = self._k8s_client.list_namespaced_pod(self._job_selector)
        if not pod_list:
            return nodes
        if not pod_list.items:
            return nodes
        ...
  • 操作 ScalePlan 的 new_job_scaler
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
class ElasticJobScaler(Scaler):
    ...
    def scale(self, plan: ScalePlan):
        scale_plan_crd = self._generate_scale_plan_crd(plan)
        self._client.create_custom_resource(
            group=ElasticJobApi.GROUP,
            version=ElasticJobApi.VERION,
            plural=ElasticJobApi.SCALEPLAN_PLURAL,
            body=scale_plan_crd.to_dict(),
        )
        self._scaleplan_index += 1
  • 处理异常的 K8sJobErrorMonitor
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
    def _handle_process_error(
        self, node: Node, restart_count: int, error_data: str
    ):
        """Record and log the error of a restart the first time it is seen.

        The error is stored in ``self._restart_errors`` keyed by
        ``restart_count``; a restart count that is already recorded is
        ignored. Always returns False.
        """
        if restart_count in self._restart_errors:
            return False
        self._restart_errors[restart_count] = error_data
        logger.error(
            f"{node.type}-{node.id} on {node.host_name} "
            f"restart {restart_count} fails: {error_data}"
        )
        return False

    def _handle_node_error(self, node: Node, error_data: str):
        """Log that the node is down and, if enabled, cordon its host.

        Always returns True.
        """
        logger.info(
            f"{node.name} on {node.host_name} is down. "
            f"Reason: {error_data}"
        )
        # The cordon call is made only when cordoning is enabled.
        if self.cordon_node_eanbled and self._k8s_client.cordon_node(
            node.host_name
        ):
            logger.info(f"Node {node.name} is marked unschedulable.")
        return True

2.2 RendezvousManager

RendezvousManager,构建训练节点的集合。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
class ElasticTrainingRendezvousManager(RendezvousManager):

    def get_comm_world(
        self, node_rank
    ) -> Tuple[int, int, Dict[int, NodeTopologyMeta]]:
    """如果一个集合点(rendezvous)轮次完成,则返回通信世界(communication world)。
    当满足以下任一条件时,集合点完成:
    1. 等待节点列表的大小等于最大节点数(max_nodes)。
    2. 等待节点列表的大小大于最小节点数(min_nodes),且等于存活节点列表的大小。此外,在等待超时(waiting_timeout)期间,没有更多的工作节点加入集合点。

    返回值:
    - rdzv_round:轮次索引。
    - group:组索引。
    - world:类似于 {0: 8, 1: 8, 2: 8} 的字典,其中键是节点ID,值是节点的本地世界大小。
    """
    ...

2.3 TaskManager

TaskManager,分配数据分片任务给工作节点,并在工作节点故障时恢复数据分片任务。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class TaskManager(object):
    """创建并分发任务,跟踪任务的生命周期。"""

    def __init__(self, worker_restart_timeout: int, speed_monitor: SpeedMonitor):
        """
        初始化 TaskManager,设置工作节点重启超时时间和速度监控器。
        """

    def new_dataset(
        self,
        batch_size,
        dataset_size,
        dataset_name,
        dataset_splitter: DatasetSplitter,
        task_type=elastic_training_pb2.NONE,
    ):
        """
        创建一个新数据集,并初始化任务管理。
        """

    def get_dataset_task(self, node_type, node_id, dataset_name):
        """
        获取指定数据集、节点类型和节点 ID 的下一个任务。
        """

    def get_dataset(self, dataset_name):
        """
        根据数据集名称获取数据集。
        """
    ...

2.4 MetricCollector

MetricCollector,收集训练任务的指标。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
def _create_metric_collector_if_needed(self, params: JobArgs):
        """Create a JobMetricCollector when dynamic sharding is enabled.

        Returns None when ``params.enable_dynamic_sharding`` is falsy.
        Otherwise the collector reports through DLROVER_BRAIN when the
        optimize mode is CLUSTER and through LOCAL in every other mode, and
        the job's distribution strategy is recorded as its job type.
        """
        if not params.enable_dynamic_sharding:
            return None
        job_uuid = params.job_uuid
        reporter = (
            ReporterType.DLROVER_BRAIN
            if params.optimize_mode == OptimizeMode.CLUSTER
            else ReporterType.LOCAL
        )
        metric_collector = JobMetricCollector(
            job_uuid, params.namespace, params.cluster, params.user, reporter
        )
        metric_collector.collect_job_type(params.distribution_strategy)
        return metric_collector

这里有一个判断:当 params.optimize_mode 为 cluster 时,数据会由 BrainReporter 上报并存储到 MySQL;否则会由 LocalReporter 上报并存储到 DLRover Master 的内存中。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
class JobMetricCollector(BaseMetricCollector):
    """Collector of job metrics.

    Only the method signatures are listed here; every body is a ``pass``
    stub.
    """

    def collect_dataset_metric(self, name, size, ds_type=DatasetType.TEXT):
        pass

    # Fixed: the listing had a duplicated ``def`` keyword on this line, which
    # is a syntax error.
    def collect_training_hyper_params(self, epoch, batch_size):
        pass

    def collect_job_type(self, job_type):
        pass

    def collect_model_metric(self, model_info: ModelInfo):
        pass

    def _report_runtime_stats(self):
        pass

    def collect_custom_data(self, metric_dict=None):
        pass

    def collect_runtime_stats(
        self, speed_monitor: SpeedMonitor, running_nodes: List[Node]
    ):
        pass

    def report_runtime_stats_periodically(self):
        pass

    def collect_job_exit_reason(self, reason):
        pass

另一方面,在 MasterServicer 中也会接收来自 Agent 使用 gRPC 上报的数据。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
class MasterServicer(elastic_training_pb2_grpc.MasterServicer):
    def report(self, request, _):
        message = grpc.deserialize_message(request.data)
        if isinstance(message, grpc.DatasetShardParams):
            success = self._collect_dataset_shard_params(message)
        elif isinstance(message, grpc.ResourceStats):
            success = self._update_node_resource_usage(
                node_type, node_id, message
            )
        elif isinstance(message, grpc.ModelInfo):
            success = self._collect_model_info(message)
        elif isinstance(message, grpc.GlobalStep):
            success = self._collect_global_step(message)
        elif isinstance(message, grpc.ShardCheckpoint):
            success = self._restore_shard_checkpoint(message)
        elif isinstance(message, grpc.TaskResult):
            success = self._report_task_result(message)
        elif isinstance(message, grpc.ClusterVersion):
            success = self._update_cluster_version(message)
        elif isinstance(message, grpc.NodeAddress):
            success = self._update_node_address(message)
        elif isinstance(message, grpc.NetworkStatus):
            success = self._update_node_status(message)
        elif isinstance(message, grpc.NodeEvent):
            success = self._update_node_event(message)
        elif isinstance(message, grpc.SyncJoin):
            success = self._join_sync(node_type, node_id, message)
        elif isinstance(message, grpc.SyncFinish):
            success = self._sync_finished(message)
        elif isinstance(message, grpc.SyncBarrier):
        ...

2.5 ElasticPSService

ElasticPSService 管理 Parameter Server 训练任务中的参数节点。

1
2
3
4
def _create_elastic_ps_service_if_needed(params: JobArgs):
    """Return a new ElasticPsService for the PS distribution strategy, else None."""
    is_ps_job = params.distribution_strategy == DistributionStrategy.PS
    return ElasticPsService() if is_ps_job else None

仅对 Parameter Server 任务有效,主要是管理 PS 任务的版本。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
class ElasticPsService(object):
    """Holds version counters for parameter servers (PS) and workers.

    Only the constructor has a body in this listing; the other methods are
    stubs that return None.
    """

    def __init__(self):
        # Global cluster version, starting at 0.
        self._global_version = 0
        # Per-node version tables, all initially empty.
        self._ps_local_version = {}
        self._worker_local_version = {}
        self._worker_restored_version = {}

    def inc_global_cluster_version(self):
        """Increase the global cluster version."""

    def get_ps_version(self, version_type, ps_id):
        """Get the version of a parameter server (PS).

        Args:
            version_type: Kind of version (global or local).
            ps_id: ID of the parameter server.
        """

    def update_ps_version(self, ps_id, version_type, version):
        """Update the version of a parameter server (PS).

        Args:
            ps_id: ID of the parameter server.
            version_type: Kind of version (global or local).
            version: Version number to set.
        """

    def get_worker_version(self, version_type, worker_id):
        """Get the version of a worker node.

        Args:
            version_type: Kind of version (global, local or restored).
            worker_id: ID of the worker node.
        """

    def update_worker_version(self, worker_id, version_type, version):
        """Update the version of a worker node.

        Args:
            worker_id: ID of the worker node.
            version_type: Kind of version (global, local or restored).
            version: Version number to set.
        """

2.6 运行 prepare 启动 gRPC 和本地 Manager 进程

1
2
3
4
5
6
7
    def prepare(self):
        """Start the master's RPC server, then the task and job managers if present."""
        # Start the RPC service on the master so that worker nodes can
        # communicate with it.
        self._master_server.start()
        if self.task_manager:
            self.task_manager.start()
        if self.job_manager:
            self.job_manager.start()

在运行之前,还需要在 Master 上启动 RPC 服务,并启动 TaskManager 和 JobManager。

  • 启动 TaskManager
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
    def start(self):
        """Start the daemon thread that checks for timed-out tasks.

        Nothing is started unless ``self._worker_restart_timeout`` is
        greater than 0.
        """
        if not self._worker_restart_timeout > 0:
            return
        checker = threading.Thread(
            target=self._check_and_reassign_timeout_tasks,
            name="check_timeout_tasks",
            daemon=True,
        )
        checker.start()
    def _check_and_reassign_timeout_tasks(self):
        """Check whether there are timeout tasks periodically.

        Loops forever, sleeping 30 seconds between scans. In each scan a
        doing task counts as timed out only when its type is EVALUATION and
        the time since its worker started the task exceeds the larger of
        ``_TASK_TIMEOUT_THRESHOLD_SECS`` and ``self._worker_restart_timeout``.
        A timed-out task is reported to its dataset as failed and the
        task-timeout callback is invoked with the task's node id.
        """
        logger.info("Start the thread to monitor timeout tasks.")
        while True:
            for _, dataset in self._datasets.items():
                # Copy doing task list because the doing list will pop items
                # in the following loop.
                doing_tasks = dataset.doing.copy()
                cur = time.time()
                for task_id, doing_task in doing_tasks.items():
                    # A node with no recorded start time defaults to "now",
                    # so its elapsed time is zero and it cannot time out.
                    start = self._worker_start_task_time.get(
                        doing_task.node_id, cur
                    )
                    if (
                        doing_task.task.task_type
                        == elastic_training_pb2.EVALUATION
                        and cur - start
                        > max(
                            _TASK_TIMEOUT_THRESHOLD_SECS,
                            self._worker_restart_timeout,
                        )
                    ):
                        logger.info(
                            f"The task {task_id} of {doing_task.node_type}-"
                            f"{doing_task.node_id} is timeout."
                        )
                        dataset.report_task_status(task_id, success=False)
                        self._invoke_task_timeout_callback(doing_task.node_id)
                        # Handle at most one timed-out task per dataset in
                        # each scan.
                        break
            time.sleep(30)
  • 启动 JobManager
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
    def start(self):
        """Initialize the job's nodes, launch the initial plan and start the monitors.

        The initial scale plan is handed to the scaler only when no worker is
        running yet. Afterwards the target worker number of the speed monitor
        is set and the monitoring daemon threads are started.
        """
        self._scaler.start()
        self._job_optimizer.update_job_uuid(self._job_args.job_uuid)
        self._job_optimizer.init_job_resource(self._job_resource)
        self._adjust_worker_for_estimator()
        self._init_nodes()
        self._init_job_auto_scaler()
        plan = self._create_initial_scale_plan()
        if not self._has_running_workers():
            # Launch the initial plan only when no worker is running. If the
            # job relaunches an evicted master, there are alive worker nodes
            # and the master does not need to launch workers.
            self._scaler.scale(plan)
        else:
            logger.info(
                "The recovered master skips launching workers at begining."
            )
        # The target worker number is the WORKER count plus the CHIEF count
        # of the initial plan.
        worker_num = 0
        if NodeType.WORKER in plan.node_group_resources:
            worker_num = plan.node_group_resources[NodeType.WORKER].count
        if NodeType.CHIEF in plan.node_group_resources:
            worker_num += plan.node_group_resources[NodeType.CHIEF].count
        self._speed_monitor.set_target_worker_num(worker_num)
        threading.Thread(
            target=self._monitor_nodes, name="node_monitor", daemon=True
        ).start()
        threading.Thread(
            target=self._monitor_node_heart_beat,
            name="node_heart_beat_monitor",
            daemon=True,
        ).start()
        # The ScalePlan monitor is started only when the
        # KUBERNETES_SERVICE_HOST environment variable is set.
        if os.getenv("KUBERNETES_SERVICE_HOST"):
            threading.Thread(
                target=self._monitor_scale_plan_crd,
                name="scaleplan_monitor",
                daemon=True,
            ).start()
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
    def _monitor_nodes(self):
        """List the nodes and process their watch events in a loop.

        The loop exits when ``self._stop_monitor`` is found set right after a
        listing has been processed.
        """
        logger.info("Start monitoring nodes events.")
        while True:
            try:
                nodes = self._node_watcher.list()
                self._process_list_nodes(nodes)
                if self._stop_monitor:
                    logger.info("Stop processing node events")
                    break
                # Watch the Pod status; each change arrives wrapped as a
                # NodeEvent and is handled uniformly by _process_event.
                for event in self._node_watcher.watch():
                    try:
                        self._process_event(event)
                    except Exception as e:
                        # A failure on one event is logged and does not end
                        # the watch loop.
                        logger.warning(e)
                        detail_trace_back = traceback.format_exc()
                        logger.warning(detail_trace_back)
            except Exception as e:
                logger.warning(e)
                # Wait before listing and watching again after a failure.
                time.sleep(30)
            time.sleep(5)

处理节点的心跳事件

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
    def _monitor_node_heart_beat(self):
        """Every 15 seconds, collect dead-node events and process them."""
        logger.info("Start monitoring the heart beat of nodes.")
        while True:
            # Only the collection of events happens under the lock; the
            # events are processed after it is released.
            with self._lock:
                events = self._get_dead_node_event()
            # A node that has not responded for more than 300s (the default
            # window of _get_dead_node_event) is considered abnormal.
            for event in events:
                try:
                    self._process_event(event)
                except Exception as e:
                    logger.warning(e)
                    detail_trace_back = traceback.format_exc()
                    logger.warning(detail_trace_back)
            time.sleep(15)

这里的 _get_dead_node_event 就是获取异常的 Pod 的事件。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
    def _get_dead_node_event(self, window_interval=300) -> List[NodeEvent]:
        """Build DELETED events for running nodes whose heartbeat has stopped.

        A node qualifies when it has reported a heartbeat before
        (``heartbeat_time > 0``), the last heartbeat is older than
        ``window_interval`` seconds, and its status is RUNNING. Each event
        carries a deep copy of the node marked FAILED with exit reason
        NO_HEARTBEAT; the tracked node object itself is not modified here.
        The error monitor is notified for every such node.

        Args:
            window_interval: Heartbeat silence, in seconds, after which a
                node is considered dead.

        Returns:
            The list of generated events; empty if no node qualifies.
        """
        current_time = time.time()
        dead_events = []
        for group in self._job_nodes.values():
            for node in group.values():
                is_dead = (
                    node.heartbeat_time > 0
                    and current_time - node.heartbeat_time > window_interval
                    and node.status == NodeStatus.RUNNING
                )
                if not is_dead:
                    continue
                failed_copy = copy.deepcopy(node)
                failed_copy.status = NodeStatus.FAILED
                failed_copy.exit_reason = NodeExitReason.NO_HEARTBEAT
                dead_events.append(
                    NodeEvent(
                        event_type=NodeEventType.DELETED,
                        node=failed_copy,
                    )
                )
                error_data = (
                    f"No heartbeat for over {window_interval} seconds."
                )
                self._error_monitor.process_error(
                    node,
                    node.relaunch_count,
                    error_data,
                    TrainingExceptionLevel.NODE_ERROR,
                )
                logger.warning(
                    f"The node {node.name} has not sent a heartbeat "
                    f"for over {window_interval} seconds."
                )
        return dead_events

这里可以看到最核心的是调用了 _process_event 以及 _process_node_events 函数。

2.7 异常处理逻辑

1
2
3
4
5
6
7
    def _process_event(self, event: NodeEvent):
        with self._lock:
            should_relaunch = self._should_relaunch(
                cur_node, status_change_flow
            )
        if should_relaunch:
            self._relaunch_node(cur_node)

第一步,判断是否需要重启

需要注意,这里 Node 的 exit_reason 是对 Pod 的封装和转换,而不是直接代表 Pod 的状态。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
    def _should_relaunch(self, node: Node, status_change_flow: NodeStateFlow):
        """Decide whether a node should be relaunched.

        A relaunch is considered only when the status change flow asks for
        it, relaunching is enabled on this manager and the node is still
        relaunchable. The node's exit reason can then veto it.

        Returns:
            True if the node should be relaunched, otherwise False.
        """
        should_relaunch = (
            status_change_flow.should_relaunch
            and self._enable_relaunch_node
            and node.relaunchable
        )
        if should_relaunch:
            # Rule out the special cases: fatal error, OOM, relaunch count
            # exhausted, killed.
            if (
                node.exit_reason == NodeExitReason.FATAL_ERROR
                and not _dlrover_context.relaunch_always
            ):
                # A fatal error is not retried unless relaunch_always is set.
                should_relaunch = False
            elif node.exit_reason == NodeExitReason.OOM:
                mem = node.config_resource.memory
                if mem >= NodeResourceLimit.MAX_MEMORY:
                    # The configured memory is already at the limit.
                    should_relaunch = False
                    logger.warning(
                        "The memory of worker %s is beyond the limit %s MB.",
                        mem,
                        NodeResourceLimit.MAX_MEMORY,
                    )
                elif node.relaunch_count >= node.max_relaunch_count:
                    should_relaunch = False
                    logger.warning(
                        "The relaunched count %s is beyond the maximum %s.",
                        node.relaunch_count,
                        node.max_relaunch_count,
                    )
                else:
                    # Relaunch after letting the optimizer adjust the node's
                    # resource for the OOM.
                    node.is_recovered_oom = True
                    self._job_optimizer.adjust_oom_resource(node)
            elif node.exit_reason != NodeExitReason.KILLED:
                # For a KILLED node the relaunch count is not checked.
                if node.relaunch_count >= node.max_relaunch_count:
                    logger.warning(
                        "The relaunch count "
                        f"{node.relaunch_count}/{node.max_relaunch_count} "
                        "has been exhausted."
                    )
                    should_relaunch = False

        return should_relaunch

第二步,重启 Node 节点,即 Pod

在 AllReduce 策略下,就是创建 Worker 节点。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
    def _relaunch_node(self, node: Node):
        """Ask the manager of the node's type for a relaunch plan and submit it.

        WORKER, PS and EVALUATOR nodes go to their own managers; CHIEF and
        MASTER nodes both go to the chief manager. A node of any other type
        is logged as unsupported and nothing is submitted.

        Args:
            node: The node to relaunch. Its ``relaunchable`` flag is cleared
                once a plan has been built.
        """
        if node.type == NodeType.WORKER:
            plan = self._worker_manager.relaunch_node(
                node, self._remove_exited_node
            )
        elif node.type == NodeType.PS:
            plan = self._ps_manager.relaunch_node(
                node, self._remove_exited_node
            )
        elif node.type == NodeType.EVALUATOR:
            plan = self._evaluator_manager.relaunch_node(
                node, self._remove_exited_node
            )
        elif node.type == NodeType.CHIEF or node.type == NodeType.MASTER:
            plan = self._chief_manager.relaunch_node(
                node, self._remove_exited_node
            )
        else:
            logger.error("Not support node type %s", node.type)
            # No plan exists for an unsupported type; falling through would
            # raise UnboundLocalError on `plan` below.
            return
        self._set_ps_addrs_in_plan(plan)
        if self._remove_exited_node:
            plan.remove_nodes.append(node)
        node.relaunchable = False  # Avoid repeatedly relaunching the node.
        self._scaler.scale(plan)

在创建一个 ScalePlan 时,需要获取旧节点的 rank_index、service_addr 等信息,用于创建新的 Pod 节点。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
    def relaunch_node(self, node: Node, remove_exited_node=False):
        """Build a ScalePlan that launches a replacement for ``node``.

        The replacement gets a fresh id from the node id iterator and is
        registered in ``self._nodes``. It keeps the old node's type, rank
        index and service address, and takes its resource and relaunch count
        from the node's relaunch info.

        Args:
            node: The node to replace.
            remove_exited_node: When true, an exited node that has not been
                released yet is marked released and added to the plan's
                removal list.

        Returns:
            The ScalePlan holding the node to launch and, possibly, the node
            to remove.
        """
        plan = ScalePlan()
        with self._lock:
            new_id = next(self._node_id_iter)
            relaunch_node = node.get_relaunch_node_info(new_id)
            self._nodes[new_id] = relaunch_node
        logger.info("Relaunch node %s to %s", node.name, new_id)

        replacement = Node(
            node.type,
            new_id,
            copy.deepcopy(relaunch_node.config_resource),
            rank_index=node.rank_index,
            name=self._new_node_name_fn(node.type, new_id),
            service_addr=node.service_addr,
            relaunch_count=relaunch_node.relaunch_count,
        )
        plan.launch_nodes.append(replacement)

        should_remove = (
            remove_exited_node and not node.is_released and node.exited()
        )
        if should_remove:
            node.is_released = True
            plan.remove_nodes.append(node)
        return plan

这里的 ScalePlan 是 Master 内部的对象,本身不会直接作为 CR 对象提交给 K8s,而是交给不同运行时对应的 _scaler 去执行(例如下面的 PodScaler)。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
class PodScaler(Scaler):
    def scale(self, plan: ScalePlan):
        """Apply ``plan`` while holding ``self._lock``.

        NOTE(review): this is an abridged excerpt; ``cur_pods``,
        ``max_pod_id`` and ``job_pods`` are not defined in the lines shown.
        """
        with self._lock:
            # Compare each group's requested count with len(cur_pods) and
            # scale that node type up or down accordingly.
            for type, group_resource in plan.node_group_resources.items():
                if group_resource.count > len(cur_pods):
                    self._scale_up_pods(type, plan, cur_pods, max_pod_id)
                elif group_resource.count < len(cur_pods):
                    self._scale_down_pods(type, plan, cur_pods)
            # Nodes listed explicitly in the plan are queued for creation.
            for node in plan.launch_nodes:
                self._create_node_queue.append(node)
            self._update_job_pods(job_pods)

在 K8s 下 Worker 节点都是 PodScaler 创建出来的。 _scale_up_pods 会将创建的 Pod 节点加入到 _create_node_queue 中。

2.8 运行 run 启动任务

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
    def run(self):
        """
        The main loop of master.
        Dispatch the tasks to the workers until all the tasks are completed.
        """
        try:
            while True:
                if self._stop_requested:
                    break
                msg = self.job_manager.early_stop()
                if msg:
                    self.request_stop(False, msg)
                    continue
                self.job_manager.clear_exited_nodes()
                if self.job_manager and self.job_manager.all_workers_exited():
                    if self.job_manager.pend_without_workers():
                        time.sleep(30)
                        continue
                    if self.job_manager.all_workers_failed():
                        logger.error("All workers failed")
                        self._exit_code = 1
                        self._exit_reason = JobExitReason.UNKNOWN_ERROR
                    elif (
                        self.task_manager and not self.task_manager.finished()
                    ):
                        logger.warning(
                            "All workers exited but there also are "
                            "unfinished tasks",
                        )
                    break

                if (
                    self.job_manager.all_running_node_hanged()
                    and self.task_manager.task_hanged()
                ):
                    logger.error("All nodes hangeds")
                    self._exit_code = 1
                    self._exit_reason = JobExitReason.HANG_ERROR

                if (
                    self.task_manager
                    and self.task_manager.finished()
                    and (
                        not self.job_manager
                        or self.job_manager.all_critical_node_completed()
                    )
                ):
                    logger.info("All task completed")
                    break

                time.sleep(30)

master 会执行一个死循环,每隔 30s 检测一次状态。具有如下状态时,master 会退出:

  • 收到停止请求 self._stop_requested
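  • 所有 worker 已退出 self.job_manager.all_workers_exited(),且不处于无 worker 的等待状态 self.job_manager.pend_without_workers()
  • task_manager 已完成 self.task_manager.finished(),且关键节点均已完成 self.job_manager.all_critical_node_completed()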

2.9 小结

DLRover 容错的逻辑主要在 Master 中,而其中的关键在 JobManager。

监控的数据源有如下几类:

  • Agent 上报的数据,包括指标、训练速度等
  • Master 从 K8s 获取的 Pod 事件数据

JobManager 基于这些上报的数据,封装为 NodeEvent 对象,然后统一由 _process_event 处理。

3. DLRover Trainer

3.1 启动脚本

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
cat /usr/local/bin/dlrover-run

#!/usr/local/bin/python
# -*- coding: utf-8 -*-
# Entry script for dlrover-run: delegates to dlrover.trainer.torch.main.main.
import re
import sys
from dlrover.trainer.torch.main import main
if __name__ == '__main__':
    # Strip a trailing "-script.pyw" or ".exe" suffix from argv[0].
    sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
    # Exit with whatever main() returns.
    sys.exit(main())

3.2 入口函数

dlrover.trainer.torch.main 调用的是 dlrover.trainer.torch.elastic_run.main

1
2
3
4
@record
def main(args=None):
    """Parse ``args`` with ``parse_args`` and pass the result to ``run``."""
    args = parse_args(args)
    run(args)

包了一层入口而已,没有实际逻辑。

3.3 参数解析

--network-check,在训练之前,先检查网络状态。
--node_unit,设置节点单元数量,调度的节点数应为此数量的倍数。
--auto_config,自动配置节点和每个节点的进程数。
--auto_tunning,自动调整并行配置。
--exclude-straggler,排除落后节点,仅在 network_check 开启时有效。
--save_at_breakpoint,训练失败时保存检查点到内存。
--accelerator,设置机器的加速器类型,如 nvidia.com/gpu、ascend-npu。

3.4 启动训练任务

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def run(args):
    """Launch the training described by ``args``.

    NOTE(review): this is an abridged excerpt; ``master_addr``,
    ``node_rank``, ``job_name`` and the initial value of
    ``master_handler`` are set in lines that are not shown here.
    """
    # Check whether the DLRover Master is reachable
    dlrover_master_ready = grpc.addr_connected(master_addr)
    _, max_nodes = parse_min_max_nnodes(args.nnodes)
    # If it is not ready and `node_rank == 0`, start a DLRover Master on this node
    if not dlrover_master_ready and node_rank == 0:
        # Only start the dlrover master on the rank-0 node.
        master_handler, master_addr = _launch_dlrover_local_master(
            master_addr,
            job_name,
            max_nodes,
        )
        logger.info(f"Set the dlrover master addr as {master_addr}")
        os.environ[NodeEnv.DLROVER_MASTER_ADDR] = master_addr
    use_dlrover_launch = _check_to_use_dlrover_run(master_addr, max_nodes)

    # Standalone mode without DLRover launch: use a local c10d rendezvous
    # with a randomly generated rendezvous id.
    if args.standalone and not use_dlrover_launch:
        args.rdzv_backend = "c10d"
        args.rdzv_endpoint = "localhost:29400"
        args.rdzv_id = str(uuid.uuid4())
        logger.info(
            f"\n**************************************\n"
            f"Rendezvous info:\n"
            f"--rdzv-backend={args.rdzv_backend} "
            f"--rdzv-endpoint={args.rdzv_endpoint} "
            f"--rdzv-id={args.rdzv_id}\n"
            f"**************************************\n"
        )
    # Parse the training arguments
    config, cmd, cmd_args = _elastic_config_from_args(args)
    config.run_id = job_name
    config.role = "dlrover-trainer"
    try:
        # Launch the training
        elastic_launch(
            config=config,
            entrypoint=cmd,
            use_dlrover_launch=use_dlrover_launch,
        )(*cmd_args)
    finally:
        # Close the local master handler, if any, once training returns or raises.
        if master_handler:
            master_handler.close()
1
2
3
4
5
6
7
8
class elastic_launch:
    def __call__(self, *args):
        """Launch the agent with ``args`` as the entrypoint arguments.

        Uses ``launch_agent`` when ``self._use_dlrover_launch`` is true,
        otherwise ``torch_launch_agent``; both receive the same config,
        entrypoint and argument list.
        """
        if self._use_dlrover_launch:
            return launch_agent(self._config, self._entrypoint, list(args))
        else:
            return torch_launch_agent(
                self._config, self._entrypoint, list(args)
            )

如果没有使用 dlrover 托管就直接使用 from torch.distributed.launcher.api import launch_agent as torch_launch_agent 启动训练。否则使用下面的函数启动训练:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def launch_agent(
    config: ElasticLaunchConfig,
    entrypoint: Union[Callable, str, None],
    args: List[Any],
) -> Dict[int, Any]:
    # 生成唯一的 `run_id`
    if not config.run_id:
        run_id = str(uuid.uuid4().int)
        logger.warning(
            f"config has no run_id, generated a random run_id: {run_id}"
        )
        config.run_id = run_id
    # 初始化监控
    monitor = TorchTrainingMonitor(ConfigPath.RUNTIME_METRICS)
    monitor.start()
    # 初始化 Agent
    ...
    agent = ElasticTrainingAgent(
        node_rank=node_rank,
        config=config,
        entrypoint=entrypoint,
        spec=spec,
        start_method=config.start_method,
        log_dir=config.log_dir,
    )
    try:
        metrics.initialize_metrics(metrics.MetricsConfig(config.metrics_cfg))
        # 启动 agent
        result = agent.run()
    ...
1
2
class ElasticTrainingAgent(LocalElasticAgent):
    # Subclass of LocalElasticAgent; the class body is elided in this excerpt.
    ...

可以看到 DLRover 使用了 PyTorch 内置的 LocalElasticAgent 负责管理节点上的训练进程。

3.5 数据上报

从上面可以看到,在启动 Agent 时,会启动一个监控任务的线程。

1
2
3
4
5
6
7
def launch_agent(
    config: ElasticLaunchConfig,
    entrypoint: Union[Callable, str, None],
    args: List[Any],
) -> Dict[int, Any]:
    # Abridged excerpt: only the monitor start-up is shown here.
    # Create the training monitor for ConfigPath.RUNTIME_METRICS and start it.
    monitor = TorchTrainingMonitor(ConfigPath.RUNTIME_METRICS)
    monitor.start()

在这个监控线程中,会周期性地上报节点的资源使用情况。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
class TorchTrainingMonitor(Singleton):
    def start(self):
        """Start the resource monitor and the background report thread.

        Returns immediately unless the environment variable named by
        ``NodeEnv.MONITOR_ENABLED`` is set to "true".
        """
        if os.getenv(NodeEnv.MONITOR_ENABLED, "false") != "true":
            return
        self._resource_monitor.start()
        # Daemon thread: it does not keep the process alive at exit.
        thread = threading.Thread(
            target=self._periodically_report,
            name="report-step",
            daemon=True,
        )
        thread.start()
    def _periodically_report(self):
        """Loop forever, reporting and sending a heartbeat every 15 seconds."""
        while True:
            # Only group rank 0 reports the resource usage with the step.
            if self._group_rank == 0:
                self.report_resource_with_step()
            self.send_heartbeat()
            time.sleep(15)

有两个上报数据的流程:

第一种是,在 Pod 中使用 psutil 获取资源使用情况。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
class ResourceMonitor(Singleton):
    def report_resource(self):
        """Collect the current resource usage and report it via the master client.

        Any exception is caught and logged with ``logger.exception``
        instead of being raised.
        """
        try:
            used_mem = get_used_memory()
            cpu_percent = get_process_cpu_percent()
            # GPU stats are refreshed only when GPU monitoring is enabled;
            # otherwise the previously stored value is reported.
            if self._gpu_enabled:
                self._gpu_stats = get_gpu_stats()
            # Scale the CPU percentage by self._total_cpu, rounded to 2 decimals.
            # presumably cpu_percent is a fraction of the total — TODO confirm
            current_cpu = round(cpu_percent * self._total_cpu, 2)
            self._master_client.report_used_resource(
                used_mem, current_cpu, self._gpu_stats
            )
            logger.debug(
                "Report Resource CPU : %s, Memory %s, GPU %s",
                current_cpu,
                used_mem,
                self._gpu_stats,
            )
        except Exception as e:
            logger.exception(e)

第二种是,从 /tmp/dlrover/runtime_metrics.json 文件中获取训练速度。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
    def report_resource_with_step(self):
        """Report resource usage and the global step read from the metrics file.

        Does nothing when the group rank is not 0 or the metrics file does
        not exist. Exceptions are caught and logged as warnings.
        """
        if self._group_rank != 0:
            return
        try:
            if not os.path.exists(self._metrics_path):
                return
            # The metrics file is JSON; missing keys default to 0.
            with open(self._metrics_path, "r") as f:
                record = json.load(f)
                step = record.get("step", 0)
                timestamp = record.get("timestamp", 0)
            # Report only when a step was recorded and the record is more than
            # 15 newer than the last reported one (presumably seconds — TODO confirm).
            if step > 0 and timestamp - self._last_timestamp > 15:
                self._resource_monitor.report_resource()
                self._last_timestamp = timestamp
                self._master_client.report_global_step(
                    step,
                    self._last_timestamp,
                )
        except Exception as e:
            logger.warning(e)

ElasticTrainer 会将梯度状态中记录的 num_steps 和当前时间戳 timestamp 写入到约定的指标文件。

1
2
3
4
5
6
7
8
9
class ElasticTrainer(object):
    def report_training_step(self):
        """Write the current step count and timestamp to the metrics file as JSON."""
        timestamp = time.time()
        record = TrainingRecord(self.gradient_state.num_steps, timestamp)
        # The file path is read from the environment variable named by
        # ConfigPath.ENV_RUNTIME_METRICS; it defaults to an empty string.
        metric_path = os.getenv(ConfigPath.ENV_RUNTIME_METRICS, "")
        rank = get_rank()
        # Only rank 0 writes, and only when the parent directory exists.
        if os.path.exists(os.path.dirname(metric_path)) and rank == 0:
            with open(metric_path, "w") as f:
                f.write(record.to_json(indent=4))

3.6 小结

Trainer 主要有两个功能:

  1. 使用 LocalElasticAgent 管理节点上的训练进程
  2. 使用 gRPC 进行数据的上报,包括训练速度、资源使用情况

4. 总结

本篇分析了 DLRover 在 Kubernetes 下的实现细节,主要涉及 AllReduce 策略下的训练任务,跳过了 PS 任务以及 Brain 相关内容:

  1. DLRover Operator 定义了作业和扩容相关的字段,维护相关状态。只是启动 DLRover Master,没有弹性、容错相关的逻辑实现
  2. 每个训练任务都会启动一个 DLRover Master,掌控着整个训练节奏,其中:
    • JobManager 用于管理作业的启停、容错、扩缩容
    • RendezvousManager 用于节点的组网
    • TaskManager 用于管理数据分片
    • MetricCollector 用于收集训练指标
    • ElasticPSService 用于管理 PS 任务中的参数节点
  3. DLRover 在处理异常时,会将检测状态封装为 NodeEvent,通过 DLRover Master 中的 _process_event 来统一处理
  4. 使用 dlrover-run 脚本启动训练任务时,DLRover 会使用 PyTorch 中的 LocalElasticAgent 管理节点上的训练进程;同时启动一个监控线程,将训练相关的指标通过 gRPC 上报给 DLRover Master

微信公众号
作者
微信公众号