Skip to content

Commit fb691af

Browse files
Support for DAGRUN and DAGRUN_RO (#8)
* [add] moved command arguments build to the respective class * [add] dagrun and dagrun_ro supported ( scriptrun included ) * [add] added ResNet-50 example to the Readme
1 parent ab65392 commit fb691af

File tree

13 files changed

+565
-39
lines changed

13 files changed

+565
-39
lines changed

README.md

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,73 @@ Example of AI.SCRIPTSET and AI.SCRIPTRUN
125125
})();
126126
```
127127

128+
Example of AI.DAGRUN enqueuing multiple SCRIPTRUN and MODELRUN commands
129+
130+
A common pattern is enqueuing multiple SCRIPTRUN and MODELRUN commands within a DAG. The following example uses ResNet-50,to classify images into 1000 object categories.
131+
132+
Given that our input tensor contains each color represented as a 8-bit integer and that neural networks usually work with floating-point tensors as their input we need to cast a tensor to floating-point and normalize the values of the pixels - for that we will use `pre_process_4ch` function.
133+
134+
To optimize the classification process we can use a post process script to return only the category position with the maximum classification - for that we will use `post_process` script.
135+
136+
Using the DAG capabilities we've removed the necessity of storing the intermediate tensors in the keyspace. You can even run the entire process without storing the output tensor, as follows:
137+
138+
139+
```javascript
140+
var redis = require('redis');
141+
var redisai = require('redisai-js');
142+
var fs = require("fs");
143+
144+
(async () => {
145+
const nativeClient = redis.createClient();
146+
const aiclient = new redisai.Client(nativeClient);
147+
const scriptFileStr = fs.readFileSync('./tests/test_data/imagenet/data_processing_script.txt').toString();
148+
const jsonLabels = fs.readFileSync('./tests/test_data/imagenet/imagenet_class_index.json');
149+
const labels = JSON.parse(jsonLabels);
150+
151+
const dataProcessingScript = new redisai.Script('CPU', scriptFileStr);
152+
const resultScriptSet = await aiclient.scriptset('data_processing_script', dataProcessingScript);
153+
// AI.SCRIPTSET result: OK
154+
console.log(`AI.SCRIPTSET result: ${resultScriptSet}`)
155+
156+
const modelBlob = fs.readFileSync('./tests/test_data/imagenet/resnet50.pb');
157+
const imagenetModel = new redisai.Model(Backend.TF, 'CPU', ['images'], ['output'], modelBlob);
158+
const resultModelSet = await aiclient.modelset('imagenet_model', imagenetModel);
159+
160+
// AI.MODELSET result: OK
161+
console.log(`AI.MODELSET result: ${resultModelSet}`)
162+
163+
const inputImage = await Jimp.read('./tests/test_data/imagenet/cat.jpg');
164+
const imageWidth = 224;
165+
const imageHeight = 224;
166+
const image = inputImage.cover(imageWidth, imageHeight);
167+
const tensor = new redisai.Tensor(Dtype.uint8, [imageWidth, imageHeight, 4], Buffer.from(image.bitmap.data));
168+
169+
///
170+
// Prepare the DAG enqueuing multiple SCRIPTRUN and MODELRUN commands
171+
const dag = new redisai.Dag();
172+
173+
dag.tensorset('tensor-image', tensor);
174+
dag.scriptrun('data_processing_script', 'pre_process_4ch', ['tensor-image'], ['temp_key1']);
175+
dag.modelrun('imagenet_model', ['temp_key1'], ['temp_key2']);
176+
dag.scriptrun('data_processing_script', 'post_process', ['temp_key2'], ['classification']);
177+
dag.tensorget('classification');
178+
179+
// Send the AI.DAGRUN command to RedisAI server
180+
const resultDagRun = await aiclient.dagrun_ro(null, dag);
181+
182+
// The 5th element of the reply will be the `classification` tensor
183+
const classTensor = resultDagRun[4];
184+
185+
// Print the category in the position with the max classification
186+
const idx = classTensor.data[0];
187+
188+
// 281 [ 'n02123045', 'tabby' ]
189+
console.log(idx, labels[idx.toString()]);
190+
191+
await aiclient.end();
192+
})();
193+
```
194+
128195
### Further examples
129196

130197
The [RedisAI examples repo](https://github.com/RedisAI/redisai-examples) shows more advanced examples
@@ -147,8 +214,8 @@ AI.SCRIPTGET | scriptget
147214
AI.SCRIPTDEL | scriptdel
148215
AI.SCRIPTRUN | scriptrun
149216
AI._SCRIPTSCAN | N/A
150-
AI.DAGRUN | N/A
151-
AI.DAGRUN_RO | N/A
217+
AI.DAGRUN | dagrun
218+
AI.DAGRUN_RO | dagrun_ro
152219
AI.INFO | info and infoResetStat (for resetting stats)
153220
AI.CONFIG * | configLoadBackend and configBackendsPath
154221

src/client.ts

Lines changed: 47 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,17 @@ import { Model } from './model';
44
import * as util from 'util';
55
import { Script } from './script';
66
import { Stats } from './stats';
7+
import { Dag } from './dag';
78

89
export class Client {
9-
private _sendCommand: any;
10+
private readonly _sendCommand: any;
1011

1112
constructor(client: RedisClient) {
1213
this._client = client;
1314
this._sendCommand = util.promisify(this._client.send_command).bind(this._client);
1415
}
1516

16-
private _client: RedisClient;
17+
private readonly _client: RedisClient;
1718

1819
get client(): RedisClient {
1920
return this._client;
@@ -23,23 +24,13 @@ export class Client {
2324
this._client.end(flush);
2425
}
2526

26-
public tensorset(keName: string, t: Tensor): Promise<any> {
27-
const args: any[] = [keName, t.dtype];
28-
t.shape.forEach((value) => args.push(value.toString()));
29-
if (t.data != null) {
30-
if (t.data instanceof Buffer) {
31-
args.push('BLOB');
32-
args.push(t.data);
33-
} else {
34-
args.push('VALUES');
35-
t.data.forEach((value) => args.push(value.toString()));
36-
}
37-
}
27+
public tensorset(keyName: string, t: Tensor): Promise<any> {
28+
const args: any[] = t.tensorSetFlatArgs(keyName);
3829
return this._sendCommand('ai.tensorset', args);
3930
}
4031

41-
public tensorget(keName: string): Promise<any> {
42-
const args: any[] = [keName, 'META', 'VALUES'];
32+
public tensorget(keyName: string): Promise<any> {
33+
const args: any[] = Tensor.tensorGetFlatArgs(keyName);
4334
return this._sendCommand('ai.tensorget', args)
4435
.then((reply: any[]) => {
4536
return Tensor.NewTensorFromTensorGetReply(reply);
@@ -55,10 +46,7 @@ export class Client {
5546
}
5647

5748
public modelrun(modelName: string, inputs: string[], outputs: string[]): Promise<any> {
58-
const args: any[] = [modelName, 'INPUTS'];
59-
inputs.forEach((value) => args.push(value));
60-
args.push('OUTPUTS');
61-
outputs.forEach((value) => args.push(value));
49+
const args: any[] = Model.modelRunFlatArgs(modelName, inputs, outputs);
6250
return this._sendCommand('ai.modelrun', args);
6351
}
6452

@@ -68,7 +56,7 @@ export class Client {
6856
}
6957

7058
public modelget(modelName: string): Promise<any> {
71-
const args: any[] = [modelName, 'META', 'BLOB'];
59+
const args: any[] = Model.modelGetFlatArgs(modelName);
7260
return this._sendCommand('ai.modelget', args)
7361
.then((reply: any[]) => {
7462
return Model.NewModelFromModelGetReply(reply);
@@ -78,22 +66,13 @@ export class Client {
7866
});
7967
}
8068

81-
public scriptset(keName: string, s: Script): Promise<any> {
82-
const args: any[] = [keName, s.device];
83-
if (s.tag !== undefined) {
84-
args.push('TAG');
85-
args.push(s.tag);
86-
}
87-
args.push('SOURCE');
88-
args.push(s.script);
69+
public scriptset(keyName: string, s: Script): Promise<any> {
70+
const args: any[] = s.scriptSetFlatArgs(keyName);
8971
return this._sendCommand('ai.scriptset', args);
9072
}
9173

9274
public scriptrun(scriptName: string, functionName: string, inputs: string[], outputs: string[]): Promise<any> {
93-
const args: any[] = [scriptName, functionName, 'INPUTS'];
94-
inputs.forEach((value) => args.push(value));
95-
args.push('OUTPUTS');
96-
outputs.forEach((value) => args.push(value));
75+
const args: any[] = Script.scriptRunFlatArgs(scriptName, functionName, inputs, outputs);
9776
return this._sendCommand('ai.scriptrun', args);
9877
}
9978

@@ -103,7 +82,7 @@ export class Client {
10382
}
10483

10584
public scriptget(scriptName: string): Promise<any> {
106-
const args: any[] = [scriptName, 'META', 'SOURCE'];
85+
const args: any[] = Script.scriptGetFlatArgs(scriptName);
10786
return this._sendCommand('ai.scriptget', args)
10887
.then((reply: any[]) => {
10988
return Script.NewScriptFromScriptGetReply(reply);
@@ -137,6 +116,40 @@ export class Client {
137116
});
138117
}
139118

119+
/**
120+
* specifies a direct acyclic graph of operations to run within RedisAI
121+
*
122+
* @param loadKeys
123+
* @param persistKeys
124+
* @param dag
125+
*/
126+
public dagrun(loadKeys: string[] | null, persistKeys: string[] | null, dag: Dag): Promise<any> {
127+
const args: any[] = dag.dagRunFlatArgs(loadKeys, persistKeys);
128+
return this._sendCommand('ai.dagrun', args)
129+
.then((reply: any[]) => {
130+
return dag.ProcessDagReply(reply);
131+
})
132+
.catch((error: any) => {
133+
throw error;
134+
});
135+
}
136+
137+
/**
138+
* specifies a Read Only direct acyclic graph of operations to run within RedisAI
139+
*
140+
* @param loadKeys
141+
* @param dag
142+
*/
143+
public dagrun_ro(loadKeys: string[] | null, dag: Dag): Promise<any> {
144+
const args: any[] = dag.dagRunFlatArgs(loadKeys, null);
145+
return this._sendCommand('ai.dagrun_ro', args)
146+
.then((reply: any[]) => {
147+
return dag.ProcessDagReply(reply);
148+
})
149+
.catch((error: any) => {
150+
throw error;
151+
});
152+
}
140153
/**
141154
* Loads the DL/ML backend specified by the backend identifier from path.
142155
*

src/dag.ts

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import { Model } from './model';
2+
import { Script } from './script';
3+
import { Tensor } from './tensor';
4+
5+
export interface DagCommandInterface {
6+
tensorset(keyName: string, t: Tensor): DagCommandInterface;
7+
8+
tensorget(keyName: string): DagCommandInterface;
9+
10+
tensorget(keyName: string): DagCommandInterface;
11+
12+
modelrun(modelName: string, inputs: string[], outputs: string[]): DagCommandInterface;
13+
14+
scriptrun(scriptName: string, functionName: string, inputs: string[], outputs: string[]): DagCommandInterface;
15+
}
16+
17+
/**
18+
* Direct mapping to RedisAI DAGs
19+
*/
20+
export class Dag implements DagCommandInterface {
21+
private _commands: any[][];
22+
private readonly _tensorgetflag: boolean[];
23+
24+
constructor() {
25+
this._commands = [];
26+
this._tensorgetflag = [];
27+
}
28+
29+
public tensorset(keyName: string, t: Tensor): Dag {
30+
const args: any[] = ['AI.TENSORSET'];
31+
t.tensorSetFlatArgs(keyName).forEach((arg) => args.push(arg));
32+
this._commands.push(args);
33+
this._tensorgetflag.push(false);
34+
return this;
35+
}
36+
37+
public tensorget(keyName: string): Dag {
38+
const args: any[] = ['AI.TENSORGET'];
39+
Tensor.tensorGetFlatArgs(keyName).forEach((arg) => args.push(arg));
40+
this._commands.push(args);
41+
this._tensorgetflag.push(true);
42+
return this;
43+
}
44+
45+
public modelrun(modelName: string, inputs: string[], outputs: string[]): Dag {
46+
const args: any[] = ['AI.MODELRUN'];
47+
Model.modelRunFlatArgs(modelName, inputs, outputs).forEach((arg) => args.push(arg));
48+
this._commands.push(args);
49+
this._tensorgetflag.push(false);
50+
return this;
51+
}
52+
53+
public scriptrun(scriptName: string, functionName: string, inputs: string[], outputs: string[]): Dag {
54+
const args: any[] = ['AI.SCRIPTRUN'];
55+
Script.scriptRunFlatArgs(scriptName, functionName, inputs, outputs).forEach((arg) => args.push(arg));
56+
this._commands.push(args);
57+
this._tensorgetflag.push(false);
58+
return this;
59+
}
60+
61+
public dagRunFlatArgs(loadKeys: string[] | null, persistKeys: string[] | null): string[] {
62+
const args: any[] = [];
63+
if (loadKeys != null && loadKeys.length > 0) {
64+
args.push('LOAD');
65+
args.push(loadKeys.length);
66+
loadKeys.forEach((value) => args.push(value));
67+
}
68+
if (persistKeys != null && persistKeys.length > 0) {
69+
args.push('PERSIST');
70+
args.push(persistKeys.length);
71+
persistKeys.forEach((value) => args.push(value));
72+
}
73+
this._commands.forEach((value) => {
74+
args.push('|>');
75+
value.forEach((arg) => args.push(arg));
76+
});
77+
return args;
78+
}
79+
80+
public ProcessDagReply(reply: any[]): any[] {
81+
for (let i = 0; i < reply.length; i++) {
82+
if (this._tensorgetflag[i] === true) {
83+
reply[i] = Tensor.NewTensorFromTensorGetReply(reply[i]);
84+
}
85+
}
86+
return reply;
87+
}
88+
}

src/index.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@ import { Backend, BackendMap } from './backend';
33
import { Tensor } from './tensor';
44
import { Model } from './model';
55
import { Script } from './script';
6+
import { Dag } from './dag';
67
import { Client } from './client';
78
import { Stats } from './stats';
89
import { Helpers } from './helpers';
910

10-
export { DTypeMap, Dtype, BackendMap, Backend, Model, Script, Tensor, Client, Stats, Helpers };
11+
export { DTypeMap, Dtype, BackendMap, Backend, Model, Script, Tensor, Dag, Client, Stats, Helpers };

src/model.ts

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,19 @@ export class Model {
187187
return model;
188188
}
189189

190-
modelSetFlatArgs(keyName: string) {
190+
static modelGetFlatArgs(keyName: string): string[] {
191+
return [keyName, 'META', 'BLOB'];
192+
}
193+
194+
static modelRunFlatArgs(modelName: string, inputs: string[], outputs: string[]): string[] {
195+
const args: string[] = [modelName, 'INPUTS'];
196+
inputs.forEach((value) => args.push(value));
197+
args.push('OUTPUTS');
198+
outputs.forEach((value) => args.push(value));
199+
return args;
200+
}
201+
202+
modelSetFlatArgs(keyName: string): any[] {
191203
const args: any[] = [keyName, this.backend.toString(), this.device];
192204
if (this.tag !== undefined) {
193205
args.push('TAG');

src/script.ts

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,4 +87,28 @@ export class Script {
8787
}
8888
return script;
8989
}
90+
91+
scriptSetFlatArgs(keyName: string): string[] {
92+
const args: string[] = [keyName, this.device];
93+
if (this.tag !== undefined) {
94+
args.push('TAG');
95+
args.push(this.tag);
96+
}
97+
args.push('SOURCE');
98+
args.push(this.script);
99+
return args;
100+
}
101+
102+
static scriptRunFlatArgs(scriptName: string, functionName: string, inputs: string[], outputs: string[]): string[] {
103+
const args: string[] = [scriptName, functionName, 'INPUTS'];
104+
inputs.forEach((value) => args.push(value));
105+
args.push('OUTPUTS');
106+
outputs.forEach((value) => args.push(value));
107+
return args;
108+
}
109+
110+
static scriptGetFlatArgs(scriptName: string): string[] {
111+
const args: string[] = [scriptName, 'META', 'SOURCE'];
112+
return args;
113+
}
90114
}

0 commit comments

Comments
 (0)