|
|
|
@ -1133,6 +1133,8 @@ export class RemoteImageGenerationRequest {
|
|
|
|
|
onError: (err: { status: number; message: string }) => void,
|
|
|
|
|
onClose: () => void
|
|
|
|
|
): Promise<void> {
|
|
|
|
|
const MAX_RETRY_TIMES = 3;
|
|
|
|
|
|
|
|
|
|
const requestStartGeneration: RequestInit = {
|
|
|
|
|
mode: 'cors',
|
|
|
|
|
cache: 'no-store',
|
|
|
|
@ -1175,72 +1177,156 @@ export class RemoteImageGenerationRequest {
|
|
|
|
|
let res = await bind.json()
|
|
|
|
|
const taskId = res.task_id
|
|
|
|
|
|
|
|
|
|
const minDelay = 2000;
|
|
|
|
|
while (true) {
|
|
|
|
|
let reqStartTime = new Date().getTime();
|
|
|
|
|
const requestGetTask: RequestInit = {
|
|
|
|
|
mode: 'cors',
|
|
|
|
|
cache: 'no-store',
|
|
|
|
|
headers: {
|
|
|
|
|
'Content-Type': 'application/json',
|
|
|
|
|
Authorization: 'Bearer ' + this.user.auth_token,
|
|
|
|
|
},
|
|
|
|
|
method: 'POST',
|
|
|
|
|
body: JSON.stringify({
|
|
|
|
|
task_id: taskId
|
|
|
|
|
}),
|
|
|
|
|
const handleTaskInfoUpdate = (taskInfo: any): boolean => {
|
|
|
|
|
let progress = 0;
|
|
|
|
|
|
|
|
|
|
if (taskInfo.status === "finished") {
|
|
|
|
|
onProgress(100, 0)
|
|
|
|
|
return true
|
|
|
|
|
} else if (taskInfo.status === "error") {
|
|
|
|
|
throw new Error("Remote error")
|
|
|
|
|
} else if (taskInfo.status === "running") {
|
|
|
|
|
if (typeof taskInfo.current_step === "number" && typeof taskInfo.total_steps === "number" && taskInfo.total_steps > 0) {
|
|
|
|
|
progress = Math.min(Math.round(taskInfo.current_step / taskInfo.total_steps * 100), 100)
|
|
|
|
|
onProgress(progress, 0)
|
|
|
|
|
}
|
|
|
|
|
} else if (taskInfo.status === "queued") {
|
|
|
|
|
onProgress(0, taskInfo.position)
|
|
|
|
|
}
|
|
|
|
|
return false
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let res: any = {};
|
|
|
|
|
if ('WebSocket' in window) {
|
|
|
|
|
try {
|
|
|
|
|
const bind = await fetchWithTimeout(BackendURLGetTaskInfo, requestGetTask)
|
|
|
|
|
if (!bind.ok) {
|
|
|
|
|
logError(bind, false)
|
|
|
|
|
|
|
|
|
|
let errorData = await bind.json()
|
|
|
|
|
|
|
|
|
|
onError({
|
|
|
|
|
status: bind.status ?? 500,
|
|
|
|
|
message: errorData.error,
|
|
|
|
|
})
|
|
|
|
|
await new Promise<void>((resolve, reject) => {
|
|
|
|
|
const wsUrl = new URL(BackendURLGetTaskInfo, location.href.replace(/^http/, 'ws'))
|
|
|
|
|
wsUrl.search = '?task_id=' + encodeURIComponent(taskId)
|
|
|
|
|
|
|
|
|
|
const wsConnect = () => {
|
|
|
|
|
let ws: WebSocket | undefined = undefined;
|
|
|
|
|
let willClose = false;
|
|
|
|
|
let hasError = false;
|
|
|
|
|
|
|
|
|
|
ws = new WebSocket(wsUrl.href)
|
|
|
|
|
|
|
|
|
|
let heartbeatTimer: NodeJS.Timer | undefined = setInterval(() => {
|
|
|
|
|
if (ws.readyState === ws.OPEN) {
|
|
|
|
|
ws.send('ping')
|
|
|
|
|
}
|
|
|
|
|
}, 15000)
|
|
|
|
|
ws.addEventListener('error', () => {
|
|
|
|
|
hasError = true
|
|
|
|
|
// reconnect on error
|
|
|
|
|
sleep(2000).then(() => {
|
|
|
|
|
wsConnect()
|
|
|
|
|
})
|
|
|
|
|
})
|
|
|
|
|
ws.addEventListener('close', () => {
|
|
|
|
|
if (heartbeatTimer) {
|
|
|
|
|
clearInterval(heartbeatTimer)
|
|
|
|
|
heartbeatTimer = undefined
|
|
|
|
|
}
|
|
|
|
|
if (!willClose && !hasError) {
|
|
|
|
|
// reconnect on abnormal disconnect
|
|
|
|
|
sleep(2000).then(() => {
|
|
|
|
|
wsConnect()
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
ws.addEventListener('message', (event) => {
|
|
|
|
|
if (!event.data) return
|
|
|
|
|
try {
|
|
|
|
|
let data = JSON.parse(event.data)
|
|
|
|
|
if (data.error && data.error === 'ERR::TASK_NOT_FOUND') {
|
|
|
|
|
onError({ status: 500, message: 'Task not found, maybe lost.' })
|
|
|
|
|
willClose = true
|
|
|
|
|
ws.close()
|
|
|
|
|
resolve()
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if (handleTaskInfoUpdate(data)) {
|
|
|
|
|
willClose = true
|
|
|
|
|
ws.close()
|
|
|
|
|
resolve()
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
} catch(err) {
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
wsConnect()
|
|
|
|
|
})
|
|
|
|
|
} catch(err) {
|
|
|
|
|
if (err.message === "Remote error") {
|
|
|
|
|
onError({ status: 500, message: 'Internal server error occurred. Please try again later' })
|
|
|
|
|
return
|
|
|
|
|
} else {
|
|
|
|
|
throw err;
|
|
|
|
|
}
|
|
|
|
|
res = await bind.json()
|
|
|
|
|
} catch (err) {
|
|
|
|
|
// don't stop generate when http error
|
|
|
|
|
}
|
|
|
|
|
// console.log('task info', res)
|
|
|
|
|
} else {
|
|
|
|
|
const minDelay = 2000;
|
|
|
|
|
while (true) {
|
|
|
|
|
let reqStartTime = new Date().getTime();
|
|
|
|
|
const requestGetTask: RequestInit = {
|
|
|
|
|
mode: 'cors',
|
|
|
|
|
cache: 'no-store',
|
|
|
|
|
headers: {
|
|
|
|
|
'Content-Type': 'application/json',
|
|
|
|
|
Authorization: 'Bearer ' + this.user.auth_token,
|
|
|
|
|
},
|
|
|
|
|
method: 'POST',
|
|
|
|
|
body: JSON.stringify({
|
|
|
|
|
task_id: taskId
|
|
|
|
|
}),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let progress = 0;
|
|
|
|
|
if (res.status === "finished") {
|
|
|
|
|
onProgress(100, 0)
|
|
|
|
|
break;
|
|
|
|
|
} else if (res.status === "error") {
|
|
|
|
|
break;
|
|
|
|
|
} else if (res.status === "running") {
|
|
|
|
|
if (typeof res.current_step === "number" && typeof res.total_steps === "number" && res.total_steps > 0) {
|
|
|
|
|
progress = Math.min(Math.round(res.current_step / res.total_steps * 100), 100)
|
|
|
|
|
onProgress(progress, 0)
|
|
|
|
|
let res: any = {};
|
|
|
|
|
try {
|
|
|
|
|
const bind = await fetchWithTimeout(BackendURLGetTaskInfo, requestGetTask)
|
|
|
|
|
if (!bind.ok) {
|
|
|
|
|
logError(bind, false)
|
|
|
|
|
|
|
|
|
|
let errorData = await bind.json()
|
|
|
|
|
|
|
|
|
|
if (errorData.code === 'ERR::TASK_NOT_FOUND') {
|
|
|
|
|
onError({ status: 500, message: 'Task not found, maybe lost.' })
|
|
|
|
|
} else {
|
|
|
|
|
onError({
|
|
|
|
|
status: bind.status ?? 500,
|
|
|
|
|
message: errorData.error,
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
res = await bind.json()
|
|
|
|
|
} catch (err) {
|
|
|
|
|
// don't stop generate when http error
|
|
|
|
|
}
|
|
|
|
|
// console.log('task info', res)
|
|
|
|
|
|
|
|
|
|
try {
|
|
|
|
|
if (handleTaskInfoUpdate(res)) {
|
|
|
|
|
break
|
|
|
|
|
}
|
|
|
|
|
} catch(err) {
|
|
|
|
|
if (err.message === "Remote error") {
|
|
|
|
|
onError({ status: 500, message: 'Internal server error occurred. Please try again later' })
|
|
|
|
|
return
|
|
|
|
|
} else {
|
|
|
|
|
throw err;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else if (res.status === "queued") {
|
|
|
|
|
onProgress(0, res.position)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let reqEndTime = new Date().getTime();
|
|
|
|
|
if (reqEndTime - reqStartTime < minDelay) {
|
|
|
|
|
await sleep(minDelay - (reqEndTime - reqStartTime))
|
|
|
|
|
let reqEndTime = new Date().getTime();
|
|
|
|
|
if (reqEndTime - reqStartTime < minDelay) {
|
|
|
|
|
await sleep(minDelay - (reqEndTime - reqStartTime))
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const timeout = setTimeout(() => {
|
|
|
|
|
source.close()
|
|
|
|
|
onError({
|
|
|
|
|
status: 408,
|
|
|
|
|
message:
|
|
|
|
|
'Error: Timeout - Unable to reach Naifu servers. Please wait for a moment and try again',
|
|
|
|
|
})
|
|
|
|
|
}, 30 * 1000)
|
|
|
|
|
|
|
|
|
|
const requestGetGenerationOutput: RequestInit = {
|
|
|
|
|
mode: 'cors',
|
|
|
|
|
cache: 'no-store',
|
|
|
|
@ -1259,28 +1345,50 @@ export class RemoteImageGenerationRequest {
|
|
|
|
|
}*/),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const source = new SSE(BackendURLGetGenerateImageOutput, {
|
|
|
|
|
headers: requestGetGenerationOutput.headers,
|
|
|
|
|
payload: requestGetGenerationOutput.body
|
|
|
|
|
})
|
|
|
|
|
source.addEventListener('newImage', (message: any) => {
|
|
|
|
|
clearTimeout(timeout)
|
|
|
|
|
onImage(Buffer.from(message.data, 'base64'), message.id)
|
|
|
|
|
})
|
|
|
|
|
source.addEventListener('error', (err: any) => {
|
|
|
|
|
clearTimeout(timeout)
|
|
|
|
|
source.close()
|
|
|
|
|
onError({
|
|
|
|
|
status: err.detail.statusCode ?? 'unknown status',
|
|
|
|
|
message: err.detail.message || err.detail.error,
|
|
|
|
|
let retryTimes = 0;
|
|
|
|
|
|
|
|
|
|
const requestOutput = () => {
|
|
|
|
|
const timeout = setTimeout(() => {
|
|
|
|
|
source.close()
|
|
|
|
|
onError({
|
|
|
|
|
status: 408,
|
|
|
|
|
message:
|
|
|
|
|
'Error: Timeout - Unable to reach Naifu servers. Please wait for a moment and try again',
|
|
|
|
|
})
|
|
|
|
|
}, 30 * 1000)
|
|
|
|
|
|
|
|
|
|
const source = new SSE(BackendURLGetGenerateImageOutput, {
|
|
|
|
|
headers: requestGetGenerationOutput.headers,
|
|
|
|
|
payload: requestGetGenerationOutput.body
|
|
|
|
|
})
|
|
|
|
|
logWarning(err, true, 'streaming error')
|
|
|
|
|
})
|
|
|
|
|
source.addEventListener('readystatechange', (e: any) => {
|
|
|
|
|
if (source.readyState === 2) {
|
|
|
|
|
onClose()
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
source.stream()
|
|
|
|
|
source.addEventListener('newImage', (message: any) => {
|
|
|
|
|
clearTimeout(timeout)
|
|
|
|
|
onImage(Buffer.from(message.data, 'base64'), message.id)
|
|
|
|
|
})
|
|
|
|
|
source.addEventListener('error', async (err: any) => {
|
|
|
|
|
clearTimeout(timeout)
|
|
|
|
|
source.close()
|
|
|
|
|
|
|
|
|
|
logWarning(err, true, 'streaming error')
|
|
|
|
|
if (retryTimes < MAX_RETRY_TIMES) { // Should retry
|
|
|
|
|
retryTimes ++
|
|
|
|
|
await sleep(2000)
|
|
|
|
|
requestOutput()
|
|
|
|
|
} else {
|
|
|
|
|
onError({
|
|
|
|
|
status: err.detail.statusCode ?? 'unknown status',
|
|
|
|
|
message: err.detail.message || err.detail.error,
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
source.addEventListener('readystatechange', (e: any) => {
|
|
|
|
|
if (source.readyState === 2) {
|
|
|
|
|
onClose()
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
source.stream()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
requestOutput()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|