|
|
import gc
|
|
|
from pydoc import isdata
|
|
|
import random
|
|
|
import string
|
|
|
import threading
|
|
|
import time
|
|
|
import traceback
|
|
|
from types import LambdaType
|
|
|
from typing import Dict, Optional, List, Any, Union
|
|
|
from typing_extensions import TypedDict
|
|
|
|
|
|
from async_timeout import asyncio
|
|
|
|
|
|
from hydra_node.models import torch_gc
|
|
|
|
|
|
class TaskQueueFullException(Exception):
    """Raised when a task is submitted while the queue is at capacity."""
|
|
|
|
|
|
class TaskDataType(TypedDict):
    """Typed-dict shape of a task record (mirrors the TaskData attributes)."""
    task_id: str        # 16-char random identifier
    position: int       # index within the queue
    status: str         # e.g. "preparing" / "queued" / "running" / "finished" / "error"
    updated_time: float  # unix timestamp of the last state change
    callback: LambdaType  # callable executed by the worker thread
    current_step: int
    total_steps: int
    payload: Any        # caller-supplied input for the callback
    response: Any       # callback return value, or the Exception on error
|
|
|
|
|
|
class TaskData:
    """State record for a single queued task.

    Instances are created and mutated by TaskQueue; every update() call
    (unless skipEvent is set) notifies the owning queue so listeners can
    observe progress.
    """

    # Owning TaskQueue, used to publish update events.  The leading
    # underscore also shields it from update().
    _task_queue = None

    task_id: str = ""
    position: int = 0            # index within the owning queue
    status: str = "preparing"    # preparing -> queued -> running -> finished/error
    updated_time: float = 0      # unix timestamp of the last update
    # staticmethod keeps the default from being bound as a method when read
    # through an instance (a bare class-level lambda would receive `self`
    # as its only argument when called as task.callback(...)).
    callback: LambdaType = staticmethod(lambda x: x)
    current_step: int = 0
    total_steps: int = 0
    # None instead of a class-level {} so instances never share one mutable
    # dict; TaskQueue.add_task always supplies a payload anyway.
    payload: Any = None
    response: Any = None

    def __init__(self, taskQueue, initData: Optional[Dict] = None):
        """Bind to `taskQueue` and apply `initData` without firing events.

        `initData` defaults to None (not a mutable `{}` default) and is
        treated as an empty mapping when omitted.
        """
        self._task_queue = taskQueue
        self.update(initData or {}, skipEvent=True)

    def update(self, data: Dict, skipEvent: bool = False):
        """Set known, public attributes from `data`.

        Raises Exception for unknown or underscore-prefixed keys.  Unless
        `skipEvent` is True, notifies the owning queue afterwards with the
        raw `data` that was applied.
        """
        for key, value in data.items():
            # startswith() also handles an empty-string key, which the
            # original `key[0]` indexing would have crashed on.
            if not key.startswith("_") and hasattr(self, key):
                setattr(self, key, value)
            else:
                raise Exception("Property %s not found on TaskData" % (key))

        if not skipEvent:
            self._task_queue._on_task_update(self, data)

    def to_dict(self):
        """Return the JSON-safe public view (no callback/payload/response)."""
        allowed_keys = ["task_id", "position", "status", "updated_time", "current_step", "total_steps"]
        return {key: getattr(self, key) for key in allowed_keys}
|
|
|
|
|
|
class TaskQueue:
|
|
|
logger = None
|
|
|
|
|
|
max_queue_size = 10
|
|
|
recovery_queue_size = 6
|
|
|
is_busy = False
|
|
|
|
|
|
loop_thread = None
|
|
|
gc_loop_thread = None
|
|
|
event_loop_thread = None
|
|
|
running = False
|
|
|
core_running = False
|
|
|
|
|
|
lock = threading.Lock()
|
|
|
event_lock = threading.Lock()
|
|
|
|
|
|
event_trigger = threading.Event()
|
|
|
event_list = []
|
|
|
|
|
|
queued_task_list: List[TaskData] = []
|
|
|
task_map: Dict[str, TaskData] = {}
|
|
|
|
|
|
on_update = lambda queue_data: None
|
|
|
|
|
|
def __init__(self, logger, max_queue_size = 10, recovery_queue_size = 6) -> None:
|
|
|
self.logger = logger
|
|
|
|
|
|
self.max_queue_size = max_queue_size
|
|
|
self.max_queue_size = recovery_queue_size
|
|
|
|
|
|
self.core_running = True
|
|
|
self.gc_loop_thread = threading.Thread(name="TaskQueueGC", target=self._gc_loop)
|
|
|
self.gc_loop_thread.start()
|
|
|
self.event_loop_thread = threading.Thread(name="TaskQueueEvent", target=self._event_loop)
|
|
|
self.event_loop_thread.start()
|
|
|
|
|
|
logger.info("Task queue created")
|
|
|
|
|
|
def add_task(self, callback: LambdaType, payload: Any = {}) -> str:
|
|
|
if self.is_busy:
|
|
|
raise TaskQueueFullException("Task queue is full")
|
|
|
if len(self.queued_task_list) >= self.max_queue_size: # mark busy
|
|
|
self.is_busy = True
|
|
|
raise TaskQueueFullException("Task queue is full")
|
|
|
|
|
|
task_id = ''.join(random.sample(string.ascii_letters + string.digits, 16))
|
|
|
task = TaskData(self, {
|
|
|
"task_id": task_id,
|
|
|
"position": 0,
|
|
|
"status": "queued",
|
|
|
"updated_time": time.time(),
|
|
|
"callback": callback,
|
|
|
"current_step": 0,
|
|
|
"total_steps": 0,
|
|
|
"payload": payload
|
|
|
})
|
|
|
with self.lock:
|
|
|
self.queued_task_list.append(task)
|
|
|
task.position = len(self.queued_task_list) - 1
|
|
|
self.logger.info("Added task: %s, queue size: %d" % (task_id, len(self.queued_task_list)))
|
|
|
# create index
|
|
|
with self.lock:
|
|
|
self.task_map[task_id] = task
|
|
|
|
|
|
self._start()
|
|
|
|
|
|
return task_id
|
|
|
|
|
|
def get_task_data(self, task_id: str) -> Union[TaskData, bool]:
|
|
|
if task_id in self.task_map:
|
|
|
return self.task_map[task_id]
|
|
|
else:
|
|
|
return False
|
|
|
|
|
|
def delete_task_data(self, task_id: str) -> bool:
|
|
|
if task_id in self.task_map:
|
|
|
with self.lock:
|
|
|
del(self.task_map[task_id])
|
|
|
else:
|
|
|
return False
|
|
|
|
|
|
def stop(self):
|
|
|
self.running = False
|
|
|
self.core_running = False
|
|
|
self.queued_task_list = []
|
|
|
self.event_list = []
|
|
|
self.event_trigger.set()
|
|
|
|
|
|
# in task thread
|
|
|
def _on_task_update(self, task: TaskData, updated_value: Dict):
|
|
|
with self.event_lock:
|
|
|
self.event_list.append(task.to_dict())
|
|
|
self.event_trigger.set()
|
|
|
|
|
|
def _start(self):
|
|
|
with self.lock:
|
|
|
if self.running == False:
|
|
|
self.loop_thread = threading.Thread(name="TaskQueue", target=self._loop)
|
|
|
self.running = True
|
|
|
self.loop_thread.start()
|
|
|
|
|
|
def _loop(self):
|
|
|
while self.running and len(self.queued_task_list) > 0:
|
|
|
current_task = self.queued_task_list[0]
|
|
|
self.logger.info("Start task:%s." % current_task.task_id)
|
|
|
try:
|
|
|
current_task.update({
|
|
|
"status": "running",
|
|
|
"updated_time": time.time(),
|
|
|
})
|
|
|
|
|
|
# run task
|
|
|
res = current_task.callback(current_task)
|
|
|
|
|
|
# call gc
|
|
|
gc.collect()
|
|
|
torch_gc()
|
|
|
|
|
|
current_task.update({
|
|
|
"status": "finished",
|
|
|
"updated_time": time.time(),
|
|
|
"response": res,
|
|
|
"current_step": current_task.total_steps
|
|
|
})
|
|
|
except Exception as e:
|
|
|
newState = {
|
|
|
"status": "error",
|
|
|
"updated_time": time.time(),
|
|
|
"response": e
|
|
|
}
|
|
|
|
|
|
gc.collect()
|
|
|
e_s = str(e)
|
|
|
if "CUDA out of memory" in e_s or \
|
|
|
"an illegal memory access" in e_s or "CUDA" in e_s:
|
|
|
traceback.print_exc()
|
|
|
self.logger.error(str(e))
|
|
|
self.logger.error("GPU error in task.")
|
|
|
torch_gc()
|
|
|
newState["response"] = Exception("GPU error in task.")
|
|
|
# os.kill(mainpid, signal.SIGTERM)
|
|
|
|
|
|
current_task.update(newState)
|
|
|
|
|
|
with self.lock:
|
|
|
self.queued_task_list.pop(0)
|
|
|
|
|
|
# reset queue position
|
|
|
for i in range(0, len(self.queued_task_list)):
|
|
|
self.queued_task_list[i].update({
|
|
|
"position": i
|
|
|
})
|
|
|
|
|
|
if self.is_busy and len(self.queued_task_list) < self.recovery_queue_size: # mark not busy
|
|
|
self.is_busy = False
|
|
|
|
|
|
self.logger.info("Task:%s finished, queue size: %d" % (current_task.task_id, len(self.queued_task_list)))
|
|
|
|
|
|
with self.lock:
|
|
|
self.running = False
|
|
|
|
|
|
# Task to remove finished task
|
|
|
def _gc_loop(self):
|
|
|
while self.core_running:
|
|
|
current_time = time.time()
|
|
|
with self.lock:
|
|
|
task_map = self.task_map.copy()
|
|
|
|
|
|
will_remove_keys = []
|
|
|
for (key, task_data) in task_map.items():
|
|
|
if task_data.status == "finished" or task_data.status == "error":
|
|
|
if task_data.updated_time + 120 < current_time: # delete task which finished 2 minutes ago
|
|
|
will_remove_keys.append(key)
|
|
|
|
|
|
if len(will_remove_keys) > 0:
|
|
|
with self.lock:
|
|
|
for key in will_remove_keys:
|
|
|
del(self.task_map[key])
|
|
|
|
|
|
time.sleep(1)
|
|
|
|
|
|
# Event loop
|
|
|
def _event_loop(self):
|
|
|
loop = asyncio.new_event_loop()
|
|
|
asyncio.set_event_loop(loop)
|
|
|
while self.core_running:
|
|
|
self.event_trigger.wait()
|
|
|
while len(self.event_list) > 0:
|
|
|
event = self.event_list[0]
|
|
|
try:
|
|
|
self.on_update(event)
|
|
|
except Exception as err:
|
|
|
print(err)
|
|
|
|
|
|
with self.event_lock:
|
|
|
self.event_list.pop(0)
|
|
|
self.event_trigger.clear()
|
|
|
|
|
|
loop.stop()
|