#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import ClassVar, Type, Dict, List, Optional, Union, cast, TYPE_CHECKING

from pyspark.util import local_connect_and_auth
from pyspark.serializers import read_int, write_int, write_with_length, UTF8Deserializer
from pyspark.errors import PySparkRuntimeError

if TYPE_CHECKING:
    from pyspark.resource import ResourceInformation


class TaskContext:
    """
    Contextual information about a task which can be read or mutated during
    execution. To access the TaskContext for a running task, use:
    :meth:`TaskContext.get`.

    .. versionadded:: 2.2.0

    Examples
    --------
    >>> from pyspark import TaskContext

    Get a task context instance from :class:`RDD`.

    >>> spark.sparkContext.setLocalProperty("key1", "value")
    >>> taskcontext = spark.sparkContext.parallelize([1]).map(lambda _: TaskContext.get()).first()
    >>> isinstance(taskcontext.attemptNumber(), int)
    True
    >>> isinstance(taskcontext.partitionId(), int)
    True
    >>> isinstance(taskcontext.stageId(), int)
    True
    >>> isinstance(taskcontext.taskAttemptId(), int)
    True
    >>> taskcontext.getLocalProperty("key1")
    'value'
    >>> isinstance(taskcontext.cpus(), int)
    True

    Get a task context instance from a dataframe via Python UDF.

    >>> from pyspark.sql import Row
    >>> from pyspark.sql.functions import udf
    >>> @udf("STRUCT<anum: INT, partid: INT, stageid: INT, taskaid: INT, prop: STRING, cpus: INT>")
    ... def taskcontext_as_row():
    ...     taskcontext = TaskContext.get()
    ...     return Row(
    ...         anum=taskcontext.attemptNumber(),
    ...         partid=taskcontext.partitionId(),
    ...         stageid=taskcontext.stageId(),
    ...         taskaid=taskcontext.taskAttemptId(),
    ...         prop=taskcontext.getLocalProperty("key2"),
    ...         cpus=taskcontext.cpus())
    ...
    >>> spark.sparkContext.setLocalProperty("key2", "value")
    >>> [(anum, partid, stageid, taskaid, prop, cpus)] = (
    ...     spark.range(1).select(taskcontext_as_row()).first()
    ... )
    >>> isinstance(anum, int)
    True
    >>> isinstance(partid, int)
    True
    >>> isinstance(stageid, int)
    True
    >>> isinstance(taskaid, int)
    True
    >>> prop
    'value'
    >>> isinstance(cpus, int)
    True

    Get a task context instance from a dataframe via Pandas UDF.

    >>> import pandas as pd  # doctest: +SKIP
    >>> from pyspark.sql.functions import pandas_udf
    >>> @pandas_udf("STRUCT<"
    ...     "anum: INT, partid: INT, stageid: INT, taskaid: INT, prop: STRING, cpus: INT>")
    ... def taskcontext_as_row(_):
    ...     taskcontext = TaskContext.get()
    ...     return pd.DataFrame({
    ...         "anum": [taskcontext.attemptNumber()],
    ...         "partid": [taskcontext.partitionId()],
    ...         "stageid": [taskcontext.stageId()],
    ...         "taskaid": [taskcontext.taskAttemptId()],
    ...         "prop": [taskcontext.getLocalProperty("key3")],
    ...         "cpus": [taskcontext.cpus()]
    ...     })  # doctest: +SKIP
    ...
    >>> spark.sparkContext.setLocalProperty("key3", "value")  # doctest: +SKIP
    >>> [(anum, partid, stageid, taskaid, prop, cpus)] = (
    ...     spark.range(1).select(taskcontext_as_row("id")).first()
    ... )  # doctest: +SKIP
    >>> isinstance(anum, int)
    True
    >>> isinstance(partid, int)
    True
    >>> isinstance(stageid, int)
    True
    >>> isinstance(taskaid, int)
    True
    >>> prop
    'value'
    >>> isinstance(cpus, int)
    True
    """

    _taskContext: ClassVar[Optional["TaskContext"]] = None

    _attemptNumber: Optional[int] = None
    _partitionId: Optional[int] = None
    _stageId: Optional[int] = None
    _taskAttemptId: Optional[int] = None
    _localProperties: Optional[Dict[str, str]] = None
    _cpus: Optional[int] = None
    _resources: Optional[Dict[str, "ResourceInformation"]] = None

    def __new__(cls: Type["TaskContext"]) -> "TaskContext":
        """
        Even if users construct :class:`TaskContext` instead of using get, give them the
        singleton.
"""taskContext=cls._taskContextiftaskContextisnotNone:returntaskContextcls._taskContext=taskContext=object.__new__(cls)returntaskContext@classmethoddef_getOrCreate(cls:Type["TaskContext"])->"TaskContext":"""Internal function to get or create global :class:`TaskContext`."""ifcls._taskContextisNone:cls._taskContext=TaskContext()returncls._taskContext@classmethoddef_setTaskContext(cls:Type["TaskContext"],taskContext:"TaskContext")->None:cls._taskContext=taskContext

    @classmethod
    def get(cls: Type["TaskContext"]) -> Optional["TaskContext"]:
        """
        Return the currently active :class:`TaskContext`. This can be called inside of
        user functions to access contextual information about running tasks.

        Returns
        -------
        :class:`TaskContext`, optional

        Notes
        -----
        Must be called on the worker, not the driver. Returns ``None`` if not initialized.
        """
        return cls._taskContext

    def stageId(self) -> int:
        """
        The ID of the stage that this task belongs to.

        Returns
        -------
        int
            current stage id.
        """
        return cast(int, self._stageId)

    def partitionId(self) -> int:
        """
        The ID of the RDD partition that is computed by this task.

        Returns
        -------
        int
            current partition id.
        """
        return cast(int, self._partitionId)

    def attemptNumber(self) -> int:
        """
        How many times this task has been attempted. The first task attempt will be
        assigned attemptNumber = 0, and subsequent attempts will have increasing
        attempt numbers.

        Returns
        -------
        int
            current attempt number.
        """
        return cast(int, self._attemptNumber)

    def taskAttemptId(self) -> int:
        """
        An ID that is unique to this task attempt (within the same :class:`SparkContext`,
        no two task attempts will share the same attempt ID). This is roughly equivalent
        to Hadoop's `TaskAttemptID`.

        Returns
        -------
        int
            current task attempt id.
        """
        return cast(int, self._taskAttemptId)

    def getLocalProperty(self, key: str) -> Optional[str]:
        """
        Get a local property set upstream in the driver, or None if it is missing.

        Parameters
        ----------
        key : str
            the key of the local property to get.

        Returns
        -------
        str, optional
            the value of the local property, or None if it is missing.
        """
        return cast(Dict[str, str], self._localProperties).get(key, None)

    def cpus(self) -> int:
        """
        CPUs allocated to the task.

        Returns
        -------
        int
            the number of CPUs.
        """
        return cast(int, self._cpus)

    def resources(self) -> Dict[str, "ResourceInformation"]:
        """
        Resources allocated to the task. The key is the resource name and the value
        is information about the resource.

        Returns
        -------
        dict
            a dictionary mapping resource names to :class:`ResourceInformation`.
        """
        from pyspark.resource import ResourceInformation

        return cast(Dict[str, "ResourceInformation"], self._resources)
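

# A minimal usage sketch for resources(), assuming the stage was scheduled with a GPU
# resource request (e.g. spark.task.resource.gpu.amount); the "gpu" key below is that
# assumption, not something this module guarantees to be present:
#
#     def describe_gpus(itr):
#         ctx = TaskContext.get()
#         gpus = ctx.resources().get("gpu")
#         if gpus is not None:
#             yield (gpus.name, list(gpus.addresses))
#
#     spark.sparkContext.parallelize(range(2), 2).mapPartitions(describe_gpus).collect()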


BARRIER_FUNCTION = 1
ALL_GATHER_FUNCTION = 2


def _load_from_socket(
    port: Optional[Union[str, int]],
    auth_secret: str,
    function: int,
    all_gather_message: Optional[str] = None,
) -> List[str]:
    """
    Load data from a given socket. This is a blocking method: it returns only when
    the socket connection has been closed.
    """
    (sockfile, sock) = local_connect_and_auth(port, auth_secret)

    # The call may block forever, so no timeout
    sock.settimeout(None)

    if function == BARRIER_FUNCTION:
        # Make a barrier() function call.
        write_int(function, sockfile)
    elif function == ALL_GATHER_FUNCTION:
        # Make an all_gather() function call.
        write_int(function, sockfile)
        write_with_length(cast(str, all_gather_message).encode("utf-8"), sockfile)
    else:
        raise ValueError("Unrecognized function type")
    sockfile.flush()

    # Collect result.
    num_results = read_int(sockfile)
    res = []
    for _ in range(num_results):
        res.append(UTF8Deserializer().loads(sockfile))

    # Release resources.
    sockfile.close()
    sock.close()

    return res
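
# Summary of the exchange performed by _load_from_socket above: write the function
# code (plus, for ALL_GATHER_FUNCTION, a length-prefixed UTF-8 message), flush, then
# read back an int count followed by that many UTF-8 strings from the JVM side.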


class BarrierTaskContext(TaskContext):
    """
    A :class:`TaskContext` with extra contextual info and tooling for tasks in a barrier stage.
    Use :func:`BarrierTaskContext.get` to obtain the barrier context for a running barrier task.

    .. versionadded:: 2.4.0

    Notes
    -----
    This API is experimental

    Examples
    --------
    Set a barrier, and execute it with RDD.

    >>> from pyspark import BarrierTaskContext
    >>> def block_and_do_something(itr):
    ...     taskcontext = BarrierTaskContext.get()
    ...     # Do something.
    ...
    ...     # Wait until all tasks finished.
    ...     taskcontext.barrier()
    ...
    ...     return itr
    ...
    >>> rdd = spark.sparkContext.parallelize([1])
    >>> rdd.barrier().mapPartitions(block_and_do_something).collect()
    [1]
    """

    _port: ClassVar[Optional[Union[str, int]]] = None
    _secret: ClassVar[Optional[str]] = None

    @classmethod
    def _getOrCreate(cls: Type["BarrierTaskContext"]) -> "BarrierTaskContext":
        """
        Internal function to get or create global :class:`BarrierTaskContext`. We need to
        make sure :class:`BarrierTaskContext` is returned from here because it is needed
        in the Python worker reuse scenario; see SPARK-25921 for more details.
        """
        if not isinstance(cls._taskContext, BarrierTaskContext):
            cls._taskContext = object.__new__(cls)
        return cls._taskContext

    @classmethod
    def get(cls: Type["BarrierTaskContext"]) -> "BarrierTaskContext":
        """
        Return the currently active :class:`BarrierTaskContext`.
        This can be called inside of user functions to access contextual information
        about running tasks.

        Notes
        -----
        Must be called on the worker, not the driver. Returns ``None`` if not initialized.
        An exception is raised if this is not called from within a barrier stage.

        This API is experimental
        """
        if not isinstance(cls._taskContext, BarrierTaskContext):
            raise PySparkRuntimeError(
                error_class="NOT_IN_BARRIER_STAGE",
                message_parameters={},
            )
        return cls._taskContext

    @classmethod
    def _initialize(
        cls: Type["BarrierTaskContext"], port: Optional[Union[str, int]], secret: str
    ) -> None:
        """
        Initialize :class:`BarrierTaskContext`; other methods within
        :class:`BarrierTaskContext` can only be called after it is initialized.
        """
        cls._port = port
        cls._secret = secret

    def barrier(self) -> None:
        """
        Sets a global barrier and waits until all tasks in this stage hit this barrier.
        Similar to the `MPI_Barrier` function in MPI, this function blocks until all
        tasks in the same stage have reached this routine.

        .. versionadded:: 2.4.0

        Notes
        -----
        This API is experimental

        In a barrier stage, each task must have the same number of `barrier()`
        calls, in all possible code branches. Otherwise, you may get the job hanging
        or a `SparkException` after timeout.
        """
        if self._port is None or self._secret is None:
            raise PySparkRuntimeError(
                error_class="CALL_BEFORE_INITIALIZE",
                message_parameters={
                    "func_name": "barrier",
                    "object": "BarrierTaskContext",
                },
            )
        else:
            _load_from_socket(self._port, self._secret, BARRIER_FUNCTION)
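
    # A minimal sketch of the pitfall noted above, assuming a barrier RDD stage: if only
    # some tasks reach the barrier (e.g. behind a data-dependent branch), the stage hangs
    # or times out, so every code path should call barrier() the same number of times:
    #
    #     def process(itr):
    #         ctx = BarrierTaskContext.get()
    #         rows = list(itr)
    #         # Wrong: only non-empty partitions would reach the barrier.
    #         # if rows:
    #         #     ctx.barrier()
    #         ctx.barrier()  # Right: every task calls barrier() exactly once.
    #         return rows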

    def allGather(self, message: str = "") -> List[str]:
        """
        This function blocks until all tasks in the same stage have reached this routine.
        Each task passes in a message and returns with a list of all the messages passed
        in by each of those tasks.

        .. versionadded:: 3.0.0

        Notes
        -----
        This API is experimental

        In a barrier stage, each task must have the same number of `allGather()`
        calls, in all possible code branches. Otherwise, you may get the job hanging
        or a `SparkException` after timeout.
        """
        if not isinstance(message, str):
            raise TypeError("Argument `message` must be of type `str`")
        elif self._port is None or self._secret is None:
            raise PySparkRuntimeError(
                error_class="CALL_BEFORE_INITIALIZE",
                message_parameters={
                    "func_name": "allGather",
                    "object": "BarrierTaskContext",
                },
            )
        else:
            return _load_from_socket(self._port, self._secret, ALL_GATHER_FUNCTION, message)
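
    # A minimal allGather() sketch, assuming a barrier RDD stage: each task contributes
    # its partition id as the message and receives the messages from all tasks back:
    #
    #     def exchange_ids(itr):
    #         ctx = BarrierTaskContext.get()
    #         messages = ctx.allGather(str(ctx.partitionId()))
    #         return [messages]
    #
    #     spark.sparkContext.parallelize(range(4), 4) \
    #         .barrier().mapPartitions(exchange_ids).collect()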

    def getTaskInfos(self) -> List["BarrierTaskInfo"]:
        """
        Returns :class:`BarrierTaskInfo` for all tasks in this barrier stage,
        ordered by partition ID.

        .. versionadded:: 2.4.0

        Notes
        -----
        This API is experimental

        Examples
        --------
        >>> from pyspark import BarrierTaskContext
        >>> rdd = spark.sparkContext.parallelize([1])
        >>> barrier_info = rdd.barrier().mapPartitions(
        ...     lambda _: [BarrierTaskContext.get().getTaskInfos()]).collect()[0][0]
        >>> barrier_info.address
        '...:...'
        """
        if self._port is None or self._secret is None:
            raise PySparkRuntimeError(
                error_class="CALL_BEFORE_INITIALIZE",
                message_parameters={
                    "func_name": "getTaskInfos",
                    "object": "BarrierTaskContext",
                },
            )
        else:
            addresses = cast(Dict[str, str], self._localProperties).get("addresses", "")
            return [BarrierTaskInfo(h.strip()) for h in addresses.split(",")]


class BarrierTaskInfo:
    """
    Carries all task infos of a barrier task.

    .. versionadded:: 2.4.0

    Attributes
    ----------
    address : str
        The IPv4 address (host:port) of the executor that the barrier task is running on

    Notes
    -----
    This API is experimental
    """

    def __init__(self, address: str) -> None:
        self.address = address