added ability to use config file to shard vicuna #1565
Conversation
43a6a4a to cd270d5
@dan-garvey / @PhaneeshB please review / merge
    config_json = json.load(config_file)
    config_file.close()
else:
    config_json = None
load config if given
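The conditional load above can be sketched as a small helper; a minimal sketch, assuming a `load_config` name that is not from the PR itself (the PR opens and closes the file explicitly rather than using a context manager):

```python
import json

def load_config(config_path):
    """Load the sharding config as JSON if a path was given, else return None.

    Hypothetical helper illustrating the pattern in the diff above;
    the argument is assumed to be a filesystem path or None.
    """
    if config_path is None:
        return None
    with open(config_path) as config_file:  # closed automatically on exit
        return json.load(config_file)
```

Using a `with` block avoids the explicit `config_file.close()` call while preserving the same behavior.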
) -> None:
    super().__init__(model_name, hf_model_path, max_num_tokens)
    self.max_sequence_length = 256
    self.device = device
    self.precision = precision
    self.tokenizer = self.get_tokenizer()
    self.config = config_json
Give the Vicuna class a config property so device indices can be accessed.
    idx_votes[int(self.config[key]["gpu"])] = 1
device_idx = max(idx_votes, key=idx_votes.get)
return device_idx
Define a function to extract the device index from the config file. The config stores the model in its most granular state, so this function looks up the device used for every layer in the shard and takes a majority vote if different devices are used throughout the shard.
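The majority-vote idea can be sketched with `collections.Counter`; the function name and config shape below are assumptions inferred from the `self.config[key]["gpu"]` access in the diff, not the PR's actual code. Note that accumulating counts (`+= 1` semantics, which `Counter` provides) rather than the plain `= 1` assignment shown in the diff is what makes the vote meaningful when layers disagree:

```python
from collections import Counter

def get_device_index(config):
    """Return the device index used by the majority of layers in the shard.

    Hypothetical sketch: `config` is assumed to map layer names to
    entries like {"gpu": 0}.
    """
    if config is None:
        return None
    # One vote per layer; Counter accumulates repeated device indices.
    votes = Counter(int(entry["gpu"]) for entry in config.values())
    # most_common(1) returns [(device_idx, count)] for the top vote-getter.
    return votes.most_common(1)[0][0]
```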
Elias, your comments are really good. But you should put them in the code instead of the review!
module = SharkInference(
    mlirs[idx],
    device=device,
-   device_idx=idx % 1,
+   device_idx=device_idx,
Change the device index to use the config file instead of defaulting to 0 (the old `idx % 1` expression is always 0).
Happy to give another review if you want to add some of your comments to the code =)
cd270d5 to cb96194
I added comments. Also, I realized that the embedding and decoding layers weren't configurable, so I added functionality for that.
You can now pass a config file generated with github.com/nod-ai/SHARK/blob/main/shark/shark_generate_model_config.py.
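For reference, a hedged guess at the shape such a config file might take, inferred only from the `self.config[key]["gpu"]` accesses in the diff; the layer names and exact structure are illustrative, not taken from the generator script:

```json
{
  "embedding": { "gpu": 0 },
  "layer0": { "gpu": 0 },
  "layer1": { "gpu": 1 },
  "decoding": { "gpu": 1 }
}
```

Each entry assigns one layer of the shard to a device index, which the majority-vote logic then collapses into a single `device_idx` per shard.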