Commit 0c655c8

Merge branch 'master' into dag.support
2 parents 54d3b9f + ab65392 commit 0c655c8

2 files changed: +131 −5 lines changed


src/model.ts

Lines changed: 65 additions & 3 deletions
@@ -11,14 +11,32 @@ export class Model {
    * @param inputs - one or more names of the model's input nodes (applicable only for TensorFlow models)
    * @param outputs - one or more names of the model's output nodes (applicable only for TensorFlow models)
    * @param blob - the Protobuf-serialized model
+   * @param batchsize - when provided with a batchsize greater than 0, the engine will batch incoming requests from multiple clients that use the model with input tensors of the same shape.
+   * @param minbatchsize - when provided with a minbatchsize greater than 0, the engine will postpone calls to AI.MODELRUN until the batch's size has reached minbatchsize.
    */
-  constructor(backend: Backend, device: string, inputs: string[], outputs: string[], blob: Buffer | undefined) {
+  constructor(
+    backend: Backend,
+    device: string,
+    inputs: string[],
+    outputs: string[],
+    blob: Buffer | undefined,
+    batchsize?: number,
+    minbatchsize?: number,
+  ) {
     this._backend = backend;
     this._device = device;
     this._inputs = inputs;
     this._outputs = outputs;
     this._blob = blob;
     this._tag = undefined;
+    this._batchsize = batchsize || 0;
+    if (this._batchsize < 0) {
+      this._batchsize = 0;
+    }
+    this._minbatchsize = minbatchsize || 0;
+    if (this._minbatchsize < 0) {
+      this._minbatchsize = 0;
+    }
   }

   // tag is an optional string for tagging the model such as a version number or any arbitrary identifier
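
Taken together, batching becomes opt-in: values omitted or below 0 collapse to 0, which leaves batching disabled. A minimal usage sketch of the new constructor, assuming the Model and Backend exports from src/model.ts and an arbitrary model blob (the values are illustrative, not part of this commit):

  import * as fs from 'fs';

  const blob: Buffer = fs.readFileSync('./tests/test_data/graph.pb');
  // batchsize 100 / minbatchsize 5: the engine may batch concurrent requests with
  // input tensors of the same shape, and holds AI.MODELRUN until at least 5 accumulate.
  const batched = new Model(Backend.TF, 'CPU', ['a', 'b'], ['c'], blob, 100, 5);
  // Negative or omitted values are clamped to 0, i.e. batching disabled.
  const unbatched = new Model(Backend.TF, 'CPU', ['a', 'b'], ['c'], blob, -1);
  // batched.batchsize === 100, batched.minbatchsize === 5
  // unbatched.batchsize === 0, unbatched.minbatchsize === 0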
@@ -86,14 +104,39 @@ export class Model {
     this._blob = value;
   }

+  private _batchsize: number;
+
+  get batchsize(): number {
+    return this._batchsize;
+  }
+
+  set batchsize(value: number) {
+    this._batchsize = value;
+  }
+
+  private _minbatchsize: number;
+
+  get minbatchsize(): number {
+    return this._minbatchsize;
+  }
+
+  set minbatchsize(value: number) {
+    this._minbatchsize = value;
+  }
+
   static NewModelFromModelGetReply(reply: any[]) {
     let backend = null;
     let device = null;
     let tag = null;
     let blob = null;
+    let batchsize: number = 0;
+    let minbatchsize: number = 0;
+    const inputs: string[] = [];
+    const outputs: string[] = [];
     for (let i = 0; i < reply.length; i += 2) {
       const key = reply[i];
       const obj = reply[i + 1];
+
       switch (key.toString()) {
         case 'backend':
           backend = BackendMap[obj.toString()];
@@ -106,9 +149,20 @@
           tag = obj.toString();
           break;
         case 'blob':
-          // blob = obj;
           blob = Buffer.from(obj);
           break;
+        case 'batchsize':
+          batchsize = parseInt(obj.toString(), 10);
+          break;
+        case 'minbatchsize':
+          minbatchsize = parseInt(obj.toString(), 10);
+          break;
+        case 'inputs':
+          obj.forEach((input) => inputs.push(input));
+          break;
+        case 'outputs':
+          obj.forEach((output) => outputs.push(output));
+          break;
       }
     }
     if (backend == null || device == null || blob == null) {
@@ -126,7 +180,7 @@
         'AI.MODELGET reply did not have the full elements to build the Model. Missing ' + missingArr.join(',') + '.',
       );
     }
-    const model = new Model(backend, device, [], [], blob);
+    const model = new Model(backend, device, inputs, outputs, blob, batchsize, minbatchsize);
     if (tag !== null) {
       model.tag = tag;
     }
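
NewModelFromModelGetReply walks the flat key/value array that AI.MODELGET returns, two entries per field. A sketch of the reply shape the switch above now handles (field order and values are illustrative; rawBytes stands in for whatever the client delivers for the blob field):

  const reply: any[] = [
    'backend', 'TF',
    'device', 'CPU',
    'tag', 'test_tag',
    'blob', rawBytes,       // wrapped with Buffer.from()
    'batchsize', '100',     // parsed with parseInt(..., 10)
    'minbatchsize', '5',
    'inputs', ['a', 'b'],   // copied element by element
    'outputs', ['c'],
  ];
  const model = Model.NewModelFromModelGetReply(reply);
  // model.batchsize === 100, model.minbatchsize === 5, model.inputs => ['a', 'b']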
@@ -151,6 +205,14 @@
       args.push('TAG');
       args.push(this.tag.toString());
     }
+    if (this.batchsize > 0) {
+      args.push('BATCHSIZE');
+      args.push(this.batchsize);
+      if (this.minbatchsize > 0) {
+        args.push('MINBATCHSIZE');
+        args.push(this.minbatchsize);
+      }
+    }
     if (this.inputs.length > 0) {
       args.push('INPUTS');
       this.inputs.forEach((value) => args.push(value));
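
Judging from this hunk, the serialized AI.MODELSET arguments carry the flags in the order TAG, BATCHSIZE, MINBATCHSIZE, INPUTS (with OUTPUTS presumably following), and MINBATCHSIZE is only emitted when BATCHSIZE is positive, matching RedisAI's rule that a minimum batch size is meaningless without batching. For a hypothetical model tagged 'test_tag' with batchsize 100 and minbatchsize 5, the portion built here would contribute:

  // ..., 'TAG', 'test_tag', 'BATCHSIZE', 100, 'MINBATCHSIZE', 5, 'INPUTS', 'a', 'b', ...

A model left at batchsize 0 contributes neither flag, so existing callers see unchanged commands.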

tests/test_client.ts

Lines changed: 66 additions & 2 deletions
@@ -331,13 +331,77 @@ it(
     const aiclient = new Client(nativeClient);

     const modelBlob: Buffer = fs.readFileSync('./tests/test_data/graph.pb');
-    const model = new Model(Backend.TF, 'CPU', ['a', 'b'], ['c'], modelBlob);
+    const inputs: string[] = ['a', 'b'];
+    const outputs: string[] = ['c'];
+    const model = new Model(Backend.TF, 'CPU', inputs, outputs, modelBlob);
     model.tag = 'test_tag';
     const resultModelSet = await aiclient.modelset('mymodel', model);
     expect(resultModelSet).to.equal('OK');

-    const modelOut = await aiclient.modelget('mymodel');
+    const modelOut: Model = await aiclient.modelget('mymodel');
     expect(modelOut.blob.toString()).to.equal(modelBlob.toString());
+    for (let index = 0; index < modelOut.outputs.length; index++) {
+      expect(modelOut.outputs[index]).to.equal(outputs[index]);
+      expect(modelOut.outputs[index]).to.equal(model.outputs[index]);
+    }
+    for (let index = 0; index < modelOut.inputs.length; index++) {
+      expect(modelOut.inputs[index]).to.equal(inputs[index]);
+      expect(modelOut.inputs[index]).to.equal(model.inputs[index]);
+    }
+    expect(modelOut.batchsize).to.equal(model.batchsize);
+    expect(modelOut.minbatchsize).to.equal(model.minbatchsize);
+    aiclient.end(true);
+  }),
+);
+
+it(
+  'ai.modelget batching positive testing',
+  mochaAsync(async () => {
+    const nativeClient = createClient();
+    const aiclient = new Client(nativeClient);
+
+    const modelBlob: Buffer = fs.readFileSync('./tests/test_data/graph.pb');
+    const inputs: string[] = ['a', 'b'];
+    const outputs: string[] = ['c'];
+    const model = new Model(Backend.TF, 'CPU', inputs, outputs, modelBlob);
+    model.tag = 'test_tag';
+    model.batchsize = 100;
+    model.minbatchsize = 5;
+    const resultModelSet = await aiclient.modelset('mymodel-batching', model);
+    expect(resultModelSet).to.equal('OK');
+    const modelOut: Model = await aiclient.modelget('mymodel-batching');
+    const resultModelSet2 = await aiclient.modelset('mymodel-batching-loop', modelOut);
+    expect(resultModelSet2).to.equal('OK');
+    const modelOut2: Model = await aiclient.modelget('mymodel-batching-loop');
+    expect(modelOut.batchsize).to.equal(model.batchsize);
+    expect(modelOut.minbatchsize).to.equal(model.minbatchsize);
+    aiclient.end(true);
+  }),
+);
+
+it(
+  'ai.modelget batching via constructor positive testing',
+  mochaAsync(async () => {
+    const nativeClient = createClient();
+    const aiclient = new Client(nativeClient);
+
+    const modelBlob: Buffer = fs.readFileSync('./tests/test_data/graph.pb');
+    const inputs: string[] = ['a', 'b'];
+    const outputs: string[] = ['c'];
+    const model = new Model(Backend.TF, 'CPU', inputs, outputs, modelBlob, 100, 5);
+    model.tag = 'test_tag';
+    const resultModelSet = await aiclient.modelset('mymodel-batching-t2', model);
+    expect(resultModelSet).to.equal('OK');
+    const modelOut: Model = await aiclient.modelget('mymodel-batching-t2');
+    const resultModelSet2 = await aiclient.modelset('mymodel-batching-loop-t2', modelOut);
+    expect(resultModelSet2).to.equal('OK');
+    const modelOut2: Model = await aiclient.modelget('mymodel-batching-loop-t2');
+    expect(modelOut.batchsize).to.equal(model.batchsize);
+    expect(modelOut.minbatchsize).to.equal(model.minbatchsize);
+
+    const model2 = new Model(Backend.TF, 'CPU', inputs, outputs, modelBlob, 1000);
+    expect(model2.batchsize).to.equal(1000);
+    expect(model2.minbatchsize).to.equal(0);
     aiclient.end(true);
   }),
 );
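
The two batching tests also exercise a round-trip property that the src/model.ts changes enable: a Model returned by modelget now carries inputs, outputs, batchsize and minbatchsize, so it can be passed straight back to modelset. Condensed, with illustrative key names:

  const out: Model = await aiclient.modelget('some-key');
  await aiclient.modelset('some-key-copy', out); // re-emits BATCHSIZE/MINBATCHSIZE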
