pdufour committed on
Commit
f9cf2b4
·
verified ·
1 Parent(s): dd6299d

Update index.js: move creation of ONNX sessions D and E from initializeSessions into imageTextToText

Browse files
Files changed (1) hide show
  1. index.js +16 -15
index.js CHANGED
@@ -48,21 +48,6 @@ async function initializeSessions() {
48
  { executionProviders: ["webgpu"] }
49
  );
50
 
51
- ortSessionD = await ort.InferenceSession.create(
52
- await getModelFile(ONNX_MODEL, `onnx/QwenVL_D_${QUANT}.onnx`),
53
- {
54
- executionProviders: ["webgpu"],
55
- }
56
- );
57
-
58
- ortSessionE = await ort.InferenceSession.create(
59
- await getModelFile(ONNX_MODEL, `onnx/QwenVL_E_${QUANT}.onnx`),
60
- {
61
- executionProviders: ["webgpu"],
62
- },
63
- );
64
-
65
-
66
  config = (await getModelJSON(BASE_MODEL, "config.json"));
67
 
68
  status.textContent = 'Ready';
@@ -255,6 +240,13 @@ export async function imageTextToText(
255
 
256
  await ortSessionA.release();
257
  ortSessionA = null;
 
 
 
 
 
 
 
258
 
259
  ({ hidden_states, position_ids } = await ortSessionD.run({
260
  "hidden_states.1": hidden_states,
@@ -276,6 +268,15 @@ export async function imageTextToText(
276
  ) {
277
  let token_id;
278
 
 
 
 
 
 
 
 
 
 
279
  ({
280
  max_logit_ids: token_id,
281
  past_key_states: past_key_states,
 
48
  { executionProviders: ["webgpu"] }
49
  );
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  config = (await getModelJSON(BASE_MODEL, "config.json"));
52
 
53
  status.textContent = 'Ready';
 
240
 
241
  await ortSessionA.release();
242
  ortSessionA = null;
243
+
244
+ ortSessionD = await ort.InferenceSession.create(
245
+ await getModelFile(ONNX_MODEL, `onnx/QwenVL_D_${QUANT}.onnx`),
246
+ {
247
+ executionProviders: ["webgpu"],
248
+ }
249
+ );
250
 
251
  ({ hidden_states, position_ids } = await ortSessionD.run({
252
  "hidden_states.1": hidden_states,
 
268
  ) {
269
  let token_id;
270
 
271
+ if (!ortSessionE) {
272
+ ortSessionE = await ort.InferenceSession.create(
273
+ await getModelFile(ONNX_MODEL, `onnx/QwenVL_E_${QUANT}.onnx`),
274
+ {
275
+ executionProviders: ["webgpu"],
276
+ },
277
+ );
278
+ }
279
+
280
  ({
281
  max_logit_ids: token_id,
282
  past_key_states: past_key_states,