Skip to content

Commit

Permalink
[FEAT][BitMoE]
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Apr 28, 2024
1 parent 9c3e7dc commit 6cb3e82
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 284 deletions.
19 changes: 19 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,25 @@ print(output.shape)

```

## `BitMoE`

```python
import torch
from bitnet.bit_moe import BitMoE

# Create input tensor
x = torch.randn(2, 4, 8)

# Create BitMoE model with specified input and output dimensions
model = BitMoE(8, 4, 2)

# Forward pass through the model
output = model(x)

# Print the output
print(output)
```

# License
MIT

Expand Down
14 changes: 14 additions & 0 deletions bit_moe_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import torch
from bitnet.bit_moe import BitMoE

# Create input tensor
x = torch.randn(2, 4, 8)

# Create BitMoE model with specified input and output dimensions
model = BitMoE(8, 4, 2)

# Forward pass through the model
output = model(x)

# Print the output
print(output)
2 changes: 2 additions & 0 deletions bitnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from bitnet.replace_hf import replace_linears_in_hf, replace_linears_in_pytorch_model
from bitnet.bit_lora import BitLora
from bitnet.bit_mamba import BitMamba
from bitnet.bit_moe import BitMoE

__all__ = [
"BitFeedForward",
Expand All @@ -19,4 +20,5 @@
"BitLinear",
"BitLora",
"BitMamba",
"BitMoE",
]
3 changes: 1 addition & 2 deletions bitnet/bit_mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ def forward(self, x):
# Return Tokens
if self.return_tokens:
x = OutputHead(self.config.dim, -1)(x)
return x
return x
else:
return x

Expand Down Expand Up @@ -639,4 +639,3 @@ def __init__(

def forward(self, x):
return self.mamba(x)

Loading

0 comments on commit 6cb3e82

Please sign in to comment.