16 | 16 | # under the License.
17 | 17 | from __future__ import annotations
18 | 18 |
19 |    | -from unittest.mock import patch
   | 19 | +from unittest.mock import MagicMock, patch
20 | 20 |
21 | 21 | import pytest
   | 22 | +from kubernetes.client import (
   | 23 | +    V1ContainerState,
   | 24 | +    V1ContainerStateWaiting,
   | 25 | +    V1ContainerStatus,
   | 26 | +    V1Pod,
   | 27 | +    V1PodStatus,
   | 28 | +)
22 | 29 |
23 | 30 | from airflow.exceptions import AirflowException
24 | 31 | from airflow.providers.cncf.kubernetes.operators.custom_object_launcher import (
   | 32 | +    CustomObjectLauncher,
25 | 33 |     SparkJobSpec,
26 | 34 |     SparkResources,
27 | 35 | )
28 | 36 |
29 | 37 |
   | 38 | +@pytest.fixture
   | 39 | +def mock_launcher():
   | 40 | +    launcher = CustomObjectLauncher(
   | 41 | +        name="test-spark-job",
   | 42 | +        namespace="default",
   | 43 | +        kube_client=MagicMock(),
   | 44 | +        custom_obj_api=MagicMock(),
   | 45 | +        template_body={
   | 46 | +            "spark": {
   | 47 | +                "spec": {
   | 48 | +                    "image": "gcr.io/spark-operator/spark-py:v3.0.0",
   | 49 | +                    "driver": {},
   | 50 | +                    "executor": {},
   | 51 | +                },
   | 52 | +                "apiVersion": "sparkoperator.k8s.io/v1beta2",
   | 53 | +                "kind": "SparkApplication",
   | 54 | +            },
   | 55 | +        },
   | 56 | +    )
   | 57 | +    launcher.pod_spec = V1Pod()
   | 58 | +    return launcher
   | 59 | +
   | 60 | +
30 | 61 | class TestSparkJobSpec:
31 | 62 |     @patch("airflow.providers.cncf.kubernetes.operators.custom_object_launcher.SparkJobSpec.update_resources")
32 | 63 |     @patch("airflow.providers.cncf.kubernetes.operators.custom_object_launcher.SparkJobSpec.validate")

@@ -150,3 +181,40 @@ def test_spark_resources_conversion(self):

150 | 181 |         assert spark_resources.executor["cpu"]["limit"] == "4"
151 | 182 |         assert spark_resources.driver["gpu"]["quantity"] == 1
152 | 183 |         assert spark_resources.executor["gpu"]["quantity"] == 2
    | 184 | +
    | 185 | +
    | 186 | +class TestCustomObjectLauncher:
    | 187 | +    def get_pod_status(self, reason: str, message: str | None = None):
    | 188 | +        return V1PodStatus(
    | 189 | +            container_statuses=[
    | 190 | +                V1ContainerStatus(
    | 191 | +                    image="test",
    | 192 | +                    image_id="test",
    | 193 | +                    name="test",
    | 194 | +                    ready=False,
    | 195 | +                    restart_count=0,
    | 196 | +                    state=V1ContainerState(
    | 197 | +                        waiting=V1ContainerStateWaiting(
    | 198 | +                            reason=reason,
    | 199 | +                            message=message,
    | 200 | +                        ),
    | 201 | +                    ),
    | 202 | +                ),
    | 203 | +            ]
    | 204 | +        )
    | 205 | +
    | 206 | +    @patch("airflow.providers.cncf.kubernetes.operators.custom_object_launcher.PodManager")
    | 207 | +    def test_check_pod_start_failure_no_error(self, mock_pod_manager, mock_launcher):
    | 208 | +        mock_pod_manager.return_value.read_pod.return_value.status = self.get_pod_status("ContainerCreating")
    | 209 | +        mock_launcher.check_pod_start_failure()
    | 210 | +
    | 211 | +        mock_pod_manager.return_value.read_pod.return_value.status = self.get_pod_status("PodInitializing")
    | 212 | +        mock_launcher.check_pod_start_failure()
    | 213 | +
    | 214 | +    @patch("airflow.providers.cncf.kubernetes.operators.custom_object_launcher.PodManager")
    | 215 | +    def test_check_pod_start_failure_with_error(self, mock_pod_manager, mock_launcher):
    | 216 | +        mock_pod_manager.return_value.read_pod.return_value.status = self.get_pod_status(
    | 217 | +            "CrashLoopBackOff", "Error message"
    | 218 | +        )
    | 219 | +        with pytest.raises(AirflowException):
    | 220 | +            mock_launcher.check_pod_start_failure()