diff --git a/app/src/main/java/com/donut/mixfile/server/core/routes/api/DownloadRoute.kt b/app/src/main/java/com/donut/mixfile/server/core/routes/api/DownloadRoute.kt index beac34f..305b6d1 100644 --- a/app/src/main/java/com/donut/mixfile/server/core/routes/api/DownloadRoute.kt +++ b/app/src/main/java/com/donut/mixfile/server/core/routes/api/DownloadRoute.kt @@ -17,6 +17,7 @@ import io.ktor.server.response.header import io.ktor.server.response.respondBytesWriter import io.ktor.server.response.respondText import io.ktor.server.routing.RoutingHandler +import io.ktor.utils.io.ByteWriteChannel import io.ktor.utils.io.close import io.ktor.utils.io.writeFully import kotlinx.coroutines.Deferred @@ -79,67 +80,64 @@ suspend fun MixFileServer.respondMixFile(call: ApplicationCall, shareInfo: MixSh } contentLength = mixFile.fileSize - range.first } - responseDownloadFileStream( - call = call, - fileList = fileList, - contentLength = contentLength, - shareInfo = shareInfo, - mixFile = mixFile, - referer = referer, - name = name - ) + call.respondBytesWriter( + contentType = name.parseFileMimeType(), + contentLength = contentLength + ) { + writeMixFileToByteChannel( + shareInfo = shareInfo, + mixFile = mixFile, + fileList = fileList, + referer = referer, + channel = this + ) + } + } -private suspend fun MixFileServer.responseDownloadFileStream( - call: ApplicationCall, - fileList: List>, - contentLength: Long, +suspend fun MixFileServer.writeMixFileToByteChannel( shareInfo: MixShareInfo, mixFile: MixFile, + fileList: List> = mixFile.fileList.map { it to 0 }, referer: String = shareInfo.referer, - name: String = shareInfo.fileName + channel: ByteWriteChannel, ) { coroutineScope { val chunkSize = mixFile.chunkSize val chunkSizeMB = chunkSize / 1.mb val taskCount = downloadTaskCount / chunkSizeMB.coerceAtLeast(1) val fileListToWrite = fileList.toMutableList() - call.respondBytesWriter( - contentType = name.parseFileMimeType(), - contentLength = contentLength - ) { - val sortedTask = SortedTask(taskCount.coerceAtLeast(1)) - val tasks = mutableListOf>() - while (!isClosedForWrite && fileListToWrite.isNotEmpty()) { - val currentFile = fileListToWrite.removeAt(0) - val taskOrder = -fileListToWrite.size - sortedTask.prepareTask(taskOrder) - tasks.add(async { - val (url, range) = currentFile - val dataBytes = try { - shareInfo.fetchFile(url, httpClient, referer) + val sortedTask = SortedTask(taskCount.coerceAtLeast(1)) + val tasks = mutableListOf>() + while (!channel.isClosedForWrite && fileListToWrite.isNotEmpty()) { + val currentFile = fileListToWrite.removeAt(0) + val taskOrder = -fileListToWrite.size + sortedTask.prepareTask(taskOrder) + tasks.add(async { + val (url, range) = currentFile + val dataBytes = try { + shareInfo.fetchFile(url, httpClient, referer) + } catch (e: Exception) { + channel.close(e) + throw e + } + sortedTask.addTask(taskOrder) { + val dataToWrite = when { + range == 0 -> dataBytes + range < 0 -> dataBytes.copyOfRange(0, -range) //一般无 < 0 的情况 + else -> dataBytes.copyOfRange(range, dataBytes.size) + } + try { + channel.writeFully(dataToWrite) + onDownloadData(dataToWrite) } catch (e: Exception) { - close(e) + channel.close(e) throw e } - sortedTask.addTask(taskOrder) { - val dataToWrite = when { - range == 0 -> dataBytes - range < 0 -> dataBytes.copyOfRange(0, -range) //一般无 < 0 的情况 - else -> dataBytes.copyOfRange(range, dataBytes.size) - } - try { - writeFully(dataToWrite) - onDownloadData(dataToWrite) - } catch (e: Exception) { - close(e) - throw e - } - } - sortedTask.execute() - }) - } - tasks.awaitAll() + } + sortedTask.execute() + }) } + tasks.awaitAll() } } \ No newline at end of file