Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added ability to use config file to shard vicuna #1565

Merged
merged 1 commit into from
Jun 22, 2023

Conversation

Eliasj42
Copy link
Contributor

you can now pass a config file generated with github.com/nod-ai/SHARK/blob/main/shark/shark_generate_model_config.py

@powderluv
Copy link
Contributor

@dan-garvey / @PhaneeshB please review / merge

config_json = json.load(config_file)
config_file.close()
else:
config_json = None
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

load config if given

) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_sequence_length = 256
self.device = device
self.precision = precision
self.tokenizer = self.get_tokenizer()
self.config = config_json
Copy link
Contributor Author

@Eliasj42 Eliasj42 Jun 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

give Vicuna class config property so device indices can be accessed

idx_votes[int(self.config[key]["gpu"])] = 1
device_idx = max(idx_votes, key=idx_votes.get)
return device_idx

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

define function to extract the device index from config file.

The config stores the model in its most granular state, so this function looks for the device used for every layer in the shard, and will use majority vote if different devices are used throughout the shard

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Elias your comments are really good. But you should put them in the code instead of the review!

module = SharkInference(
mlirs[idx],
device=device,
device_idx=idx % 1,
device_idx=device_idx,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change device index to use config file instead of defaulting to 0

dan-garvey
dan-garvey previously approved these changes Jun 22, 2023
Copy link
Member

@dan-garvey dan-garvey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to give another review if you want to add some of your comments to the code = )

idx_votes[int(self.config[key]["gpu"])] = 1
device_idx = max(idx_votes, key=idx_votes.get)
return device_idx

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Elias your comments are really good. But you should put them in the code instead of the review!

@Eliasj42
Copy link
Contributor Author

I added comments.

Also, I realized that the embbeding and decoding layers weren't configurable, so I added in functionality for that

@dan-garvey dan-garvey merged commit 8822b9a into main Jun 22, 2023
@dan-garvey dan-garvey deleted the add-sharding-config-2 branch November 3, 2023 21:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants