Add new optimizer state row_counter for Adam [Backend] #3342

Open
spcyppt wants to merge 1 commit into main

Conversation

spcyppt (Contributor) commented Nov 8, 2024

Summary:
A new optional optimizer state `row_counter` is added to Adam to perform bias correction per embedding row. `row_counter` counts how many times a row (an index) has occurred and serves as that row's iteration counter when applying bias correction.

Without rowwise bias correction (existing Adam):

```
m_hat_t = m_t / (1.0 - powf(beta1, iter));
v_hat_t = v_t / (1.0 - powf(beta2, iter));
```

With rowwise bias correction enabled:

```
// when index `idx` occurs
_row_counter = row_counter[idx] + 1;
m_hat_t = m_t / (1.0 - powf(beta1, _row_counter));
v_hat_t = v_t / (1.0 - powf(beta2, _row_counter));
```
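
For intuition, here is a minimal plain-Python sketch (illustrative only, not the FBGEMM kernel) of an Adam update for a single embedding row with the optional row-wise bias correction; the function name and the list-based state layout are assumptions made for this example.

```
# Illustrative sketch only (not the FBGEMM kernel): an Adam update for one
# embedding row, with optional row-wise bias correction. All per-row states
# (weights, m, v, row_counter) are plain Python lists here.
import math

def adam_update_row(weights, grads, m, v, global_iter, row_counter, idx,
                    lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8,
                    use_rowwise_bias_correction=False):
    if use_rowwise_bias_correction:
        row_counter[idx] += 1       # this row occurred: advance its own counter
        t = row_counter[idx]
    else:
        t = global_iter             # existing Adam: global iteration count

    for d in range(len(weights[idx])):          # elementwise over embedding dim
        m[idx][d] = beta1 * m[idx][d] + (1.0 - beta1) * grads[d]
        v[idx][d] = beta2 * v[idx][d] + (1.0 - beta2) * grads[d] ** 2
        m_hat = m[idx][d] / (1.0 - beta1 ** t)  # bias-corrected first moment
        v_hat = v[idx][d] / (1.0 - beta2 ** t)  # bias-corrected second moment
        weights[idx][d] -= lr * m_hat / (math.sqrt(v_hat) + eps)
```

The difference matters most for rarely seen rows: with beta2 = 0.999, a row updated only 5 times gives 1 - beta2^5 ≈ 0.005, whereas the global counter at iteration 10,000 gives ≈ 1.0, so the second-moment correction differs by roughly 200x.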

This request is from IG to allow all models to be scaled on sparse features, with an expected 1.5% NE on Stories.


The functionality is not enabled by default.

To enable the bias correction, `use_rowwise_bias_correction` needs to be set to True through `extra_optimizer_config`:

```
extra_optimizer_config = UserEnabledConfigDefinition(use_rowwise_bias_correction=True)
emb_op = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[
        (E, D, M, compute_device) for (E, D, M) in zip(Es, Ds, managed)
    ],
    optimizer=OptimType.ADAM,
    extra_optimizer_config=extra_optimizer_config,
    ...
)
```
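
For context, a rough usage sketch follows. It assumes the `emb_op` construction above (with the elided arguments filled in, on a CUDA device) and the usual TBE calling convention of a flattened `indices` tensor plus an `offsets` tensor; treat the exact shapes and device placement as assumptions rather than a definitive recipe.

```
# Hypothetical usage sketch: assumes `emb_op` was built as above on a CUDA
# device with use_rowwise_bias_correction=True.
import torch

# Two pooled bags over one table: bag 0 reads rows {1, 5}, bag 1 reads rows {5, 7}.
indices = torch.tensor([1, 5, 5, 7], dtype=torch.int64, device="cuda")
offsets = torch.tensor([0, 2, 4], dtype=torch.int64, device="cuda")

out = emb_op(indices, offsets)  # forward: pooled embeddings, one row per bag
out.sum().backward()            # backward runs the fused optimizer update; only
                                # rows that appear in `indices` (1, 5, 7) have
                                # their row_counter advanced and bias-corrected.
```

Since the functionality is off by default, existing Adam behavior (global-iteration bias correction) is unchanged unless the flag is set.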

Differential Revision: D64808460


facebook-github-bot (Contributor) commented:
This pull request was exported from Phabricator. Differential Revision: D64808460


netlify bot commented Nov 8, 2024

Deploy Preview for pytorch-fbgemm-docs ready!

Name                 | Link
🔨 Latest commit     | f2cf409
🔍 Latest deploy log | https://app.netlify.com/sites/pytorch-fbgemm-docs/deploys/673431499feb1600084ec209
😎 Deploy Preview    | https://deploy-preview-3342--pytorch-fbgemm-docs.netlify.app

spcyppt added a commit to spcyppt/FBGEMM that referenced this pull request Nov 8, 2024
Summary:

X-link: facebookresearch/FBGEMM#436

A new optional optimizer state `row_counter` is added to Adam to perform bias correction per embedding row (see the PR summary above for the formulas and motivation).

-------

**__The functionality is not enabled by default.__** Frontend: D64848802

------
**__Performance from Kineto__** (unweighted)
```
                   Baseline* |  default** | enabled*** 
forward  | cpu  |   2.293 s  |   2.188 s  |   2.043 s
         | cuda |  12.512 ms |  12.539 ms |  12.547 ms
backward | cpu  |  69.861 ms |  66.546 ms |  65.880 ms
         | cuda | 103.429 ms | 103.395 ms | 103.130 ms
```
\* Baseline: before changes
\** default: default setting; use_rowwise_bias_correction = False
\*** enabled: use_rowwise_bias_correction = True

Reviewed By: sryap

Differential Revision: D64808460