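"""Threaded task queue for hydra_node.

Queued tasks run one at a time on a worker thread; finished or failed
tasks are dropped from the index about two minutes after completion, and
every status change is forwarded to the ``on_update`` hook.
"""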
import asyncio
import gc
import random
import string
import threading
import time
import traceback
from types import LambdaType
from typing import Dict, Optional, List, Any, Union

from typing_extensions import TypedDict

from hydra_node.models import torch_gc

class TaskQueueFullException(Exception):
    pass


class TaskDataType(TypedDict):
    task_id: str
    position: int
    status: str
    updated_time: float
    callback: LambdaType
    current_step: int
    total_steps: int
    payload: Any
    response: Any


class TaskData:
    """State of a single queued task."""
    _task_queue = None
    task_id: str = ""
    position: int = 0
    status: str = "preparing"
    updated_time: float = 0
    callback: LambdaType = lambda x: x
    current_step: int = 0
    total_steps: int = 0
    payload: Any = {}
    response: Any = None

    def __init__(self, taskQueue, initData: Dict = {}):
        self._task_queue = taskQueue
        self.update(initData, skipEvent=True)

    def update(self, data: Dict, skipEvent: bool = False):
        # Copy known public attributes, then notify the owning queue.
        for (key, value) in data.items():
            if key[0] != "_" and hasattr(self, key):
                setattr(self, key, value)
            else:
                raise Exception("Property %s not found on TaskData" % (key))
        if not skipEvent:
            self._task_queue._on_task_update(self, data)

    def to_dict(self):
        allowedItems = ["task_id", "position", "status", "updated_time", "current_step", "total_steps"]
        retDict = {}
        for itemKey in allowedItems:
            retDict[itemKey] = getattr(self, itemKey)
        return retDict


class TaskQueue:
    logger = None
    max_queue_size = 10
    recovery_queue_size = 6
    is_busy = False
    loop_thread = None
    gc_loop_thread = None
    event_loop_thread = None
    running = False
    core_running = False
    lock = threading.Lock()
    event_lock = threading.Lock()
    event_trigger = threading.Event()
    event_list = []
    queued_task_list: List[TaskData] = []
    task_map: Dict[str, TaskData] = {}
    # Subscriber hook called with a task dict on every update. Wrapped in
    # staticmethod so self.on_update(event) does not bind the queue instance
    # as an extra first argument.
    on_update = staticmethod(lambda queue_data: None)

    def __init__(self, logger, max_queue_size=10, recovery_queue_size=6) -> None:
        self.logger = logger
        self.max_queue_size = max_queue_size
        self.recovery_queue_size = recovery_queue_size
        self.core_running = True
        self.gc_loop_thread = threading.Thread(name="TaskQueueGC", target=self._gc_loop)
        self.gc_loop_thread.start()
        self.event_loop_thread = threading.Thread(name="TaskQueueEvent", target=self._event_loop)
        self.event_loop_thread.start()
        logger.info("Task queue created")

    def add_task(self, callback: LambdaType, payload: Any = {}) -> str:
        if self.is_busy:
            raise TaskQueueFullException("Task queue is full")
        if len(self.queued_task_list) >= self.max_queue_size:  # mark busy
            self.is_busy = True
            raise TaskQueueFullException("Task queue is full")
        task_id = ''.join(random.sample(string.ascii_letters + string.digits, 16))
        task = TaskData(self, {
            "task_id": task_id,
            "position": 0,
            "status": "queued",
            "updated_time": time.time(),
            "callback": callback,
            "current_step": 0,
            "total_steps": 0,
            "payload": payload
        })
        with self.lock:
            self.queued_task_list.append(task)
            task.position = len(self.queued_task_list) - 1
        self.logger.info("Added task: %s, queue size: %d" % (task_id, len(self.queued_task_list)))
        # create index
        with self.lock:
            self.task_map[task_id] = task
        self._start()
        return task_id

    def get_task_data(self, task_id: str) -> Union[TaskData, bool]:
        if task_id in self.task_map:
            return self.task_map[task_id]
        else:
            return False

    def delete_task_data(self, task_id: str) -> bool:
        if task_id in self.task_map:
            with self.lock:
                del self.task_map[task_id]
            return True
        else:
            return False

    def stop(self):
        self.running = False
        self.core_running = False
        self.queued_task_list = []
        self.event_list = []
        self.event_trigger.set()

    # Called from the task thread whenever a task is updated.
    def _on_task_update(self, task: TaskData, updated_value: Dict):
        with self.event_lock:
            self.event_list.append(task.to_dict())
        self.event_trigger.set()

    def _start(self):
        with self.lock:
            if not self.running:
                self.loop_thread = threading.Thread(name="TaskQueue", target=self._loop)
                self.running = True
                self.loop_thread.start()

    def _loop(self):
        while self.running and len(self.queued_task_list) > 0:
            current_task = self.queued_task_list[0]
            self.logger.info("Start task %s." % current_task.task_id)
            try:
                current_task.update({
                    "status": "running",
                    "updated_time": time.time(),
                })
                # run task
                res = current_task.callback(current_task)
                # call gc
                gc.collect()
                torch_gc()
                current_task.update({
                    "status": "finished",
                    "updated_time": time.time(),
                    "response": res,
                    "current_step": current_task.total_steps
                })
            except Exception as e:
                newState = {
                    "status": "error",
                    "updated_time": time.time(),
                    "response": e
                }
                gc.collect()
                e_s = str(e)
                if "CUDA out of memory" in e_s or \
                        "an illegal memory access" in e_s or "CUDA" in e_s:
                    traceback.print_exc()
                    self.logger.error(str(e))
                    self.logger.error("GPU error in task.")
                    torch_gc()
                    newState["response"] = Exception("GPU error in task.")
                    # os.kill(mainpid, signal.SIGTERM)
                current_task.update(newState)
            with self.lock:
                self.queued_task_list.pop(0)
                # reset queue positions of the remaining tasks
                for i in range(0, len(self.queued_task_list)):
                    self.queued_task_list[i].update({
                        "position": i
                    })
                if self.is_busy and len(self.queued_task_list) < self.recovery_queue_size:  # mark not busy
                    self.is_busy = False
            self.logger.info("Task %s finished, queue size: %d" % (current_task.task_id, len(self.queued_task_list)))
        with self.lock:
            self.running = False

    # Background loop that removes finished or errored tasks from the index.
    def _gc_loop(self):
        while self.core_running:
            current_time = time.time()
            with self.lock:
                task_map = self.task_map.copy()
            will_remove_keys = []
            for (key, task_data) in task_map.items():
                if task_data.status == "finished" or task_data.status == "error":
                    if task_data.updated_time + 120 < current_time:  # drop tasks that finished over 2 minutes ago
                        will_remove_keys.append(key)
            if len(will_remove_keys) > 0:
                with self.lock:
                    for key in will_remove_keys:
                        del self.task_map[key]
            time.sleep(1)

    # Event loop: forwards queued task-update events to the on_update hook.
    def _event_loop(self):
        # Give this thread its own asyncio event loop, set as current, for
        # callbacks that may need one.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        while self.core_running:
            self.event_trigger.wait()
            while len(self.event_list) > 0:
                event = self.event_list[0]
                try:
                    self.on_update(event)
                except Exception as err:
                    print(err)
                with self.event_lock:
                    self.event_list.pop(0)
            self.event_trigger.clear()
        loop.stop()
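

# --- Minimal usage sketch (illustrative only) ---
# Assumes nothing beyond the classes above; the callback, logger and payload
# below are hypothetical stand-ins for whatever the real service passes in
# (e.g. a model-inference function).
if __name__ == "__main__":
    import logging

    logging.basicConfig(level=logging.INFO)
    demo_logger = logging.getLogger("task_queue_demo")

    def demo_callback(task: TaskData):
        # Report progress through the task object, then return a response.
        task.update({"total_steps": 1, "current_step": 1, "updated_time": time.time()})
        return {"echo": task.payload}

    queue = TaskQueue(demo_logger, max_queue_size=4, recovery_queue_size=2)
    queue.on_update = lambda event: demo_logger.info("Task update: %s" % event)
    demo_task_id = queue.add_task(demo_callback, payload={"prompt": "hello"})
    time.sleep(2)  # give the worker thread time to run the task
    demo_logger.info("Final state: %s" % queue.get_task_data(demo_task_id).to_dict())
    queue.stop()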