Predictor
Predictor class managing the distributed inference process.
Attributes:

| Name | Type | Description |
|---|---|---|
| mesh | jax Mesh | Mesh used for distributed inference. |
Source code in redco/predictors/predictor.py
mesh (property)
Returns the mesh used for distributed inference.
__init__(deployer, collate_fn, pred_fn, output_fn=None, params_sharding_rules=None)
Initializes a Predictor instance.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| deployer | Deployer | A deployer for low-level operations. | required |
| collate_fn | Callable | A function making model inputs from raw data, e.g., tokenizing sentences into input_ids. | required |
| pred_fn | Callable | A function producing model outputs from inputs, e.g., running beam search with a language model. | required |
| output_fn | Callable | A function post-processing model outputs, e.g., decoding generated ids to text. | None |
| params_sharding_rules | PyTree | Rules for sharding parameters. | None |
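The three user-supplied functions form a pipeline: raw examples pass through collate_fn, then pred_fn, then output_fn. The toy sketch below illustrates that contract in plain Python. The vocabulary, the "reverse the tokens" model, and the `(batch, params)` argument order of `pred_fn` are all hypothetical. The exact arguments redco passes to `pred_fn` are not documented on this page.

```python
# Hypothetical vocabulary standing in for a real tokenizer.
VOCAB = {"<pad>": 0, "hello": 1, "world": 2, "jax": 3}
ID_TO_WORD = {i: w for w, i in VOCAB.items()}


def collate_fn(examples):
    """Raw examples -> model inputs: padded token ids plus an attention mask."""
    token_ids = [[VOCAB[w] for w in ex["text"].split()] for ex in examples]
    max_len = max(len(ids) for ids in token_ids)
    pad = VOCAB["<pad>"]
    input_ids = [ids + [pad] * (max_len - len(ids)) for ids in token_ids]
    attention_mask = [[1] * len(ids) + [0] * (max_len - len(ids)) for ids in token_ids]
    return {"input_ids": input_ids, "attention_mask": attention_mask}


def pred_fn(batch, params):
    """Toy 'model' (assumed signature): reverse the real, unpadded tokens.

    `params` is unused here; a real pred_fn would run the model with it.
    """
    outputs = []
    for ids, mask in zip(batch["input_ids"], batch["attention_mask"]):
        n_real = sum(mask)
        outputs.append(ids[:n_real][::-1])
    return outputs


def output_fn(model_outputs):
    """Post-process raw outputs: decode ids back to text, dropping padding."""
    return [" ".join(ID_TO_WORD[t] for t in ids if t != 0) for ids in model_outputs]


examples = [{"text": "hello world"}, {"text": "jax"}]
batch = collate_fn(examples)
preds = output_fn(pred_fn(batch, params=None))
print(preds)  # ['world hello', 'jax']
```

Keeping tokenization and decoding out of `pred_fn` means only the numeric, fixed-shape part of the pipeline has to run on accelerators.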
predict(examples, per_device_batch_size, params, params_replicated=False, params_sharded=False, desc=None)
Runs distributed prediction on a list of examples.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| examples | list | Input examples for prediction. | required |
| per_device_batch_size | int | Batch size per device. | required |
| params | dict | Model parameters in a dict/FrozenDict. | required |
| params_replicated | bool | Whether the params are already replicated across devices. | False |
| params_sharded | bool | Whether the params are already sharded across devices. | False |
| desc | str | Description to show in the progress bar. | None |
Returns:

| Type | Description |
|---|---|
| list | A list of predictions corresponding to the input examples. |
setup_running_step(dummy_batch, params_shape_or_params)
Sets up the prediction step function for distributed inference.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| dummy_batch | PyTree | A dummy batch used to determine data shapes. | required |
| params_shape_or_params | dict | Either the shapes of the params or the actual params. | required |