Skip to content

Commit

Permalink
Fix requestAnimationFrame not exiting the loop properly.
Browse files Browse the repository at this point in the history
  • Loading branch information
webees committed Sep 16, 2023
1 parent bfc4b38 commit ff14bab
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 49 deletions.
2 changes: 1 addition & 1 deletion src/App.vue
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import { showLoadingToast, closeToast } from 'vant'
import i18n from '@/vue-i18n'
import { yolo } from '@/vue-pinia'
import { loadModel } from '@/utils/tf'
import { loadModel } from '@/composables/yolo'
// title
const route = useRoute()
Expand Down
49 changes: 16 additions & 33 deletions src/utils/tf.ts → src/composables/yolo.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import type { GraphModel, io, Rank, Tensor, Tensor1D, Tensor2D, Tensor3D } from '@tensorflow/tfjs'
import type { Rank, Tensor, Tensor1D, Tensor2D, Tensor3D } from '@tensorflow/tfjs'
import * as tf from '@tensorflow/tfjs'
import '@tensorflow/tfjs-backend-webgl'
import { yolo } from '@/vue-pinia'

import { renderBoxes } from '@/utils/renderBox'
import labels from '@/utils/labels.json'

const numClass = labels.length
import { renderBoxes } from '@/utils/renderBox'

export function loadModel() {
tf.ready().then(async () => {
Expand Down Expand Up @@ -70,27 +67,22 @@ function preprocess(source: HTMLVideoElement | HTMLImageElement, modelWidth: num

/**
* Run inference on the given source and perform object detection.
* @param {tf.GraphModel} model loaded YOLOv8 tensorflow.js model
* @param {number[]} inputShape
* @param {HTMLImageElement|HTMLVideoElement} source
* @param {HTMLCanvasElement} canvasRef canvas reference
* @param {VoidFunction} callback function to run after detection process
*/
export async function detect(
model: GraphModel<string | io.IOHandler>,
inputShape: number[],
source: HTMLImageElement | HTMLVideoElement,
canvasRef: HTMLCanvasElement,
callback: () => void
) {
export async function detect(source: HTMLImageElement | HTMLVideoElement, canvasRef: HTMLCanvasElement, callback: () => void) {
const model = toRaw(yolo().model)
const inputShape = toRaw(yolo().inputShape)

tf.engine().startScope() // start scoping tf engine
const [modelWidth, modelHeight] = inputShape.slice(1, 3) // get model width and height
// console.log('shape', modelWidth, modelHeight)

const [input, xRatio, yRatio] = preprocess(source, modelWidth, modelHeight) // preprocess image
// console.log('ratio', xRatio, yRatio)

const res = toRaw(model).execute(input) as Tensor<Rank> // Must use toRaw() inference model.
const res = model!.execute(input) as Tensor<Rank> // Must use toRaw() inference model.
const transRes = res.transpose([0, 2, 1]) // transpose result [b, det, n] => [b, n, det]

const boxes = tf.tidy(() => {
Expand All @@ -112,12 +104,11 @@ export async function detect(
}) as Tensor2D // process boxes [y1, x1, y2, x2]

const [scores, classes] = tf.tidy(() => {
const rawScores = transRes.slice([0, 0, 4], [-1, -1, numClass]).squeeze() // #6 only squeeze axis 0 to handle only 1 class models
const rawScores = transRes.slice([0, 0, 4], [-1, -1, labels.length]).squeeze() // #6 only squeeze axis 0 to handle only 1 class models
return [rawScores.max(1), rawScores.argMax(1)]
}) as [Tensor1D, Tensor2D] // get max scores and classes index

const nms = await tf.image.nonMaxSuppressionAsync(boxes, scores, 500, 0.45, 0.2) // NMS to filter boxes

const boxes_data = boxes.gather(nms, 0).dataSync() // indexing boxes by nms index
const scores_data = scores.gather(nms, 0).dataSync() // indexing scores by nms index
const classes_data = classes.gather(nms, 0).dataSync() // indexing classes by nms index
Expand All @@ -131,34 +122,26 @@ export async function detect(
}

/**
* Function to detect video from every source.
* Function to detect every frame from video
* @param {HTMLVideoElement} source video source
* @param {tf.GraphModel} model loaded YOLOv8 tensorflow.js model
* @param {HTMLCanvasElement} canvasRef canvas reference
*/
export function detectVideo(model: GraphModel<string | io.IOHandler>, inputShape: number[], source: HTMLVideoElement, canvasRef: HTMLCanvasElement) {
/**
* Function to detect every frame from video
*/
export function detectVideo(source: HTMLVideoElement, canvasRef: HTMLCanvasElement) {
let animationId = -1
const detectFrame = async () => {
if (source.videoWidth === 0 && source.srcObject === null) {
console.warn('source.srcObject === null')
if (source.paused) {
console.log('source.paused', source.paused)
const ctx = canvasRef.getContext('2d')
ctx && ctx.clearRect(0, 0, ctx.canvas.width, ctx.canvas.height) // clean canvas
ctx && ctx.clearRect(0, 0, ctx.canvas.width, ctx.canvas.height)
cancelAnimationFrame(animationId)
console.warn('cancelAnimationFrame', animationId)
return // handle if source is closed
}

detect(model, inputShape, source, canvasRef, () => {
detect(source, canvasRef, () => {
animationId = requestAnimationFrame(detectFrame) // get another frame
})
}

detectFrame() // initialize to detect every frame
return animationId
}

/**
 * Cancel a scheduled detection frame.
 *
 * Passes `id` to `cancelAnimationFrame` and then logs the id with `console.warn`.
 *
 * @param id animation frame request id to cancel
 *
 * NOTE(review): the only caller visible here passes the value returned by `detectVideo`. That
 * value is read right after the first `detectFrame()` call, while the real request id is only
 * assigned later inside the `detect` callback, so it may still be -1 at that point. Confirm.
 */
export function unDetectVideo(id: number) {
cancelAnimationFrame(id)
console.warn('unDetectVideo', id)
}
11 changes: 5 additions & 6 deletions src/views/image.vue
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,28 @@
<span class="absolute w-full h-full flex justify-center items-center">
<van-uploader :after-read="afterRead" reupload max-count="1">
<canvas v-if="imageUrl" ref="canvasRef" class="absolute top-0 w-full h-full" :width="yolo().inputShape[1]" :height="yolo().inputShape[2]" />
<img v-if="imageUrl" ref="imageRef" :src="imageUrl" @load="onImageLoadDetect" class="w-full h-full" />
<img v-if="imageUrl" ref="imageRef" class="w-full h-full" :src="imageUrl" @load="onImageLoad" />
<van-button v-else icon="plus" type="primary">{{ i18n.t('Open Image') }}</van-button>
</van-uploader>
</span>
</template>

<script lang="ts" setup>
import type { GraphModel, io } from '@tensorflow/tfjs'
import { UploaderFileListItem } from 'vant'
import { detect } from '@/utils/tf'
import { detect } from '@/composables/yolo'
import i18n from '@/vue-i18n'
import { yolo } from '@/vue-pinia'
const imageUrl = ref()
const imageRef = ref()
const imageRef = ref<HTMLImageElement>()
const canvasRef = ref()
// uploader
function afterRead(file: UploaderFileListItem | UploaderFileListItem[]) {
imageUrl.value = (file as UploaderFileListItem).objectUrl
}
function onImageLoadDetect() {
detect(yolo().model as GraphModel<string | io.IOHandler>, yolo().inputShape, imageRef.value, canvasRef.value, () => {
function onImageLoad() {
detect(imageRef.value as HTMLImageElement, canvasRef.value, () => {
console.log('detect done')
})
}
Expand Down
17 changes: 8 additions & 9 deletions src/views/video.vue
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,33 @@
<span class="absolute w-full h-full flex justify-center items-center">
<van-uploader :after-read="afterRead" reupload max-count="1" accept="video/*">
<canvas v-if="videoUrl" ref="canvasRef" class="absolute top-0 w-full h-full" :width="yolo().inputShape[1]" :height="yolo().inputShape[2]" />
<video v-if="videoUrl" ref="videoRef" class="w-full h-full" :src="videoUrl" controls autoplay @play="onVideoPlayDetect" />
<video v-if="videoUrl" ref="videoRef" class="w-full h-full" :src="videoUrl" @play="onVideoPlay" controls muted autoplay />
<van-button v-else icon="plus" type="primary">{{ i18n.t('Open Video') }}</van-button>
</van-uploader>
</span>
</template>

<script lang="ts" setup>
import type { GraphModel, io } from '@tensorflow/tfjs'
import { UploaderFileListItem } from 'vant'
import { detectVideo, unDetectVideo } from '@/utils/tf'
import { detectVideo } from '@/composables/yolo'
import i18n from '@/vue-i18n'
import { yolo } from '@/vue-pinia'
const videoUrl = ref()
const videoRef = ref()
const videoUrl = ref<string>()
const videoRef = ref<HTMLVideoElement>()
const canvasRef = ref()
let videoID = ref()
// uploader
function afterRead(file: UploaderFileListItem | UploaderFileListItem[]) {
videoUrl.value = (file as UploaderFileListItem).objectUrl
}
function onVideoPlayDetect() {
videoID.value = detectVideo(yolo().model as GraphModel<string | io.IOHandler>, yolo().inputShape, videoRef.value, canvasRef.value)
function onVideoPlay() {
detectVideo(videoRef.value as HTMLVideoElement, canvasRef.value)
}
onBeforeUnmount(() => {
unDetectVideo(videoID.value)
videoUrl.value = ''
videoRef.value!.pause()
})
</script>

Expand Down

0 comments on commit ff14bab

Please sign in to comment.