-
Notifications
You must be signed in to change notification settings - Fork 54
/
main.go
305 lines (270 loc) · 9.94 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
package main
import (
"encoding/json"
"errors"
"fmt"
"io"
"log"
"os"
"os/exec"
"path"
"path/filepath"
"runtime"
"time"
hfd "github.com/bodaay/HuggingFaceModelDownloader/hfdownloader"
"github.com/joho/godotenv"
"github.com/spf13/cobra"
)
// VERSION is the release string shown in the root command's help text.
const (
	VERSION = "1.4.2"
)
// Config holds every hfdownloader setting. Values start from DefaultConfig,
// are overlaid by ~/.config/hfdownloader.json (see LoadConfig), and are then
// overridden by command-line flags. The JSON tags are the config-file keys;
// the field order is also the key order written by generate-config.
type Config struct {
	// NumConnections is the number of concurrent connections (-c).
	NumConnections int `json:"num_connections"`

	// RequiresAuth is loaded from the config file; nothing in this file reads it.
	RequiresAuth bool `json:"requires_auth"`

	// AuthToken is the HuggingFace token (-t). When empty, main falls back to
	// the HF_TOKEN (or deprecated HUGGING_FACE_HUB_TOKEN) environment variable.
	AuthToken string `json:"auth_token"`

	// ModelName is the model to download (-m); it takes precedence over DatasetName.
	ModelName string `json:"model_name"`

	// DatasetName is the dataset to download (-d) when ModelName is empty.
	DatasetName string `json:"dataset_name"`

	// Branch is the branch of the model or dataset (-b).
	Branch string `json:"branch"`

	// Storage is the storage path for downloads (-s).
	Storage string `json:"storage"`

	// OneFolderPerFilter appends the filter name to the output folder (-f).
	OneFolderPerFilter bool `json:"one_folder_per_filter"`

	// SkipSHA skips the SHA256 hash check (-k).
	SkipSHA bool `json:"skip_sha"`

	// MaxRetries is the number of download attempts (--maxRetries).
	MaxRetries int `json:"max_retries"`

	// RetryInterval is the pause between attempts, in seconds (--retryInterval).
	RetryInterval int `json:"retry_interval"`

	// JustDownload is the default for -j: treat the first positional argument
	// as the model name and download into the current directory.
	JustDownload bool `json:"just_download"`

	// SilentMode disables progress bar output printing (-q).
	SilentMode bool `json:"silent_mode"`
}
// DefaultConfig returns the baseline settings used before the config file and
// flags are applied: 5 concurrent connections, the "main" branch, the current
// directory ("./") as storage, 3 download attempts and a 5-second retry
// interval. Every other field keeps its zero value.
func DefaultConfig() Config {
	var defaults Config
	defaults.NumConnections = 5
	defaults.Branch = "main"
	defaults.Storage = "./"
	defaults.MaxRetries = 3
	defaults.RetryInterval = 5
	return defaults
}
// LoadConfig returns the base configuration: DefaultConfig() overlaid with
// the JSON in ~/.config/hfdownloader.json when that file exists. If the
// HFDOWNLOADER_JUST_DOWNLOAD environment variable is "1" or "true",
// just-download mode is enabled and Storage is forced to "./" — whether or
// not a config file exists.
//
// It returns an error if the home directory cannot be determined, or if the
// config file exists but cannot be read or parsed.
func LoadConfig() (*Config, error) {
	config := DefaultConfig() // Use defaults as a base
	homeDir, err := os.UserHomeDir()
	if err != nil {
		return nil, err
	}
	configPath := filepath.Join(homeDir, ".config", "hfdownloader.json")

	data, err := os.ReadFile(configPath)
	switch {
	case err == nil:
		if err := json.Unmarshal(data, &config); err != nil {
			return nil, fmt.Errorf("parsing %s: %w", configPath, err)
		}
	case os.IsNotExist(err):
		// No config file: keep the defaults, but still honor the env var below.
	default:
		// The file exists but is unreadable; silently ignoring it would hide
		// the user's settings without explanation.
		return nil, fmt.Errorf("reading %s: %w", configPath, err)
	}

	// An environment variable can always enable the 'just download' feature.
	if v := os.Getenv("HFDOWNLOADER_JUST_DOWNLOAD"); v == "1" || v == "true" {
		config.JustDownload = true
		config.Storage = "./" // just-download always targets the current directory
	}
	return &config, nil
}
// generateConfigFile writes DefaultConfig() as indented JSON to
// ~/.config/hfdownloader.json and prints the path it wrote. The ~/.config
// directory is created if it does not exist yet, and an existing file at
// that path is overwritten.
func generateConfigFile() error {
	homeDir, err := os.UserHomeDir()
	if err != nil {
		return err
	}
	configDir := filepath.Join(homeDir, ".config")
	// On a fresh account ~/.config may not exist; without this OpenFile fails.
	if err := os.MkdirAll(configDir, 0o755); err != nil {
		return fmt.Errorf("creating %s: %w", configDir, err)
	}
	configPath := filepath.Join(configDir, "hfdownloader.json")

	config := DefaultConfig()
	file, err := os.OpenFile(configPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
	if err != nil {
		return err
	}
	encoder := json.NewEncoder(file)
	encoder.SetIndent("", " ")
	if err := encoder.Encode(config); err != nil {
		file.Close() // the encode error is the one worth reporting
		return err
	}
	// Close explicitly: a failed final write-back must not be reported as success.
	if err := file.Close(); err != nil {
		return err
	}
	fmt.Printf("Generated config file at: %s\n", configPath)
	return nil
}
func main() {
config, err := LoadConfig()
if err != nil {
log.Fatalf("Failed to load configuration: %v", err)
}
var justDownload bool
var (
install bool
installPath string
)
ShortString := fmt.Sprintf("a Simple HuggingFace Models Downloader Utility\nVersion: %s", VERSION)
currentPath, err := os.Executable()
if err != nil {
log.Printf("Failed to get execuable path, %s", err)
}
if currentPath != "" {
ShortString = fmt.Sprintf("%s\nRunning on: %s", ShortString, currentPath)
}
rootCmd := &cobra.Command{
Use: "hfdownloader [model]",
Short: ShortString,
SilenceErrors: true,
SilenceUsage: true,
Args: func(cmd *cobra.Command, args []string) error {
if justDownload && len(args) < 1 {
return errors.New("requires a model name argument when using -j")
}
return nil
},
RunE: func(cmd *cobra.Command, args []string) error {
if justDownload {
config.ModelName = args[0] // Use the first argument as the model name
config.Storage = "./"
}
// Validate the ModelName parameter
// if !hfdn.IsValidModelName(modelName) { Just realized there are indeed models that don't follow this format :)
// // fmt.Println("Error:", err)
// return fmt.Errorf("Invailid Model Name, it should follow the pattern: ModelAuthor/ModelName")
// }
// Dynamic configuration updates (e.g., for AuthToken)
if config.AuthToken == "" {
config.AuthToken = os.Getenv("HF_TOKEN")
if config.AuthToken == "" {
config.AuthToken = os.Getenv("HUGGING_FACE_HUB_TOKEN")
if config.AuthToken != "" {
fmt.Println("DeprecationWarning: The environment variable 'HUGGING_FACE_HUB_TOKEN' is deprecated and will be removed in a future version. Please use 'HF_TOKEN' instead.")
}
}
}
if install {
if err := installBinary(installPath); err != nil {
log.Fatal(err)
}
os.Exit(0)
}
var IsDataset bool
ModelOrDataSet := config.ModelName
if config.ModelName != "" {
fmt.Println("Model:", config.ModelName)
IsDataset = false
} else if config.DatasetName != "" {
fmt.Println("Dataset:", config.DatasetName)
IsDataset = true
ModelOrDataSet = config.DatasetName
} else {
cmd.Help()
return fmt.Errorf("Error: You must set either modelName or datasetName.")
}
_ = godotenv.Load() // Load .env file if exists
if config.AuthToken == "" {
config.AuthToken = os.Getenv("HF_TOKEN")
if config.AuthToken == "" {
config.AuthToken = os.Getenv("HUGGING_FACE_HUB_TOKEN")
if config.AuthToken != "" {
fmt.Println("DeprecationWarning: The environment variable 'HUGGING_FACE_HUB_TOKEN' is deprecated and will be removed in a future version. Please use 'HF_TOKEN' instead.")
}
}
}
fmt.Printf("Branch: %s\nStorage: %s\nNumberOfConcurrentConnections: %d\nAppend Filter Names to Folder: %t\nSkip SHA256 Check: %t\nToken: %s\n",
config.Branch, config.Storage, config.NumConnections, config.OneFolderPerFilter, config.SkipSHA, config.AuthToken)
for i := 0; i < config.MaxRetries; i++ {
if err := hfd.DownloadModel(ModelOrDataSet, config.OneFolderPerFilter, config.SkipSHA, IsDataset, config.Storage, config.Branch, config.NumConnections, config.AuthToken, config.SilentMode); err != nil {
fmt.Printf("Warning: attempt %d / %d failed, error: %s\n", i+1, config.MaxRetries, err)
time.Sleep(time.Duration(config.RetryInterval) * time.Second)
continue
}
fmt.Printf("\nDownload of %s completed successfully\n", ModelOrDataSet)
return nil
}
return fmt.Errorf("failed to download %s after %d attempts", ModelOrDataSet, config.MaxRetries)
},
}
// Setup flags and bind them to config properties
rootCmd.PersistentFlags().StringVarP(&config.ModelName, "model", "m", config.ModelName, "Model name to download")
rootCmd.PersistentFlags().StringVarP(&config.DatasetName, "dataset", "d", config.DatasetName, "Dataset name to download")
rootCmd.PersistentFlags().StringVarP(&config.Branch, "branch", "b", config.Branch, "Branch of the model or dataset")
rootCmd.PersistentFlags().StringVarP(&config.Storage, "storage", "s", config.Storage, "Storage path for downloads")
rootCmd.PersistentFlags().IntVarP(&config.NumConnections, "concurrent", "c", config.NumConnections, "Number of concurrent connections")
rootCmd.PersistentFlags().StringVarP(&config.AuthToken, "token", "t", config.AuthToken, "HuggingFace Auth Token")
rootCmd.PersistentFlags().BoolVarP(&config.OneFolderPerFilter, "appendFilterFolder", "f", config.OneFolderPerFilter, "Append filter name to folder")
rootCmd.PersistentFlags().BoolVarP(&config.SkipSHA, "skipSHA", "k", config.SkipSHA, "Skip SHA256 hash check")
rootCmd.PersistentFlags().IntVar(&config.MaxRetries, "maxRetries", config.MaxRetries, "Maximum number of retries for downloads")
rootCmd.PersistentFlags().IntVar(&config.RetryInterval, "retryInterval", config.RetryInterval, "Interval between retries in seconds")
rootCmd.PersistentFlags().BoolVarP(&justDownload, "justDownload", "j", config.JustDownload, "Just download the model to the current directory and assume the first argument is the model name")
rootCmd.Flags().BoolVarP(&install, "install", "i", false, "Install the binary to the OS default bin folder, Unix-like operating systems only")
rootCmd.Flags().StringVarP(&installPath, "installPath", "p", "/usr/local/bin/", "install Path (optional)")
rootCmd.PersistentFlags().BoolVarP(&config.SilentMode, "silentMode", "q", config.SilentMode, "Disable progress bar output printing")
// Add the generate-config command
generateCmd := &cobra.Command{
Use: "generate-config",
Short: "Generates an example configuration file with default values",
RunE: func(cmd *cobra.Command, args []string) error {
return generateConfigFile()
},
}
rootCmd.AddCommand(generateCmd)
if err := rootCmd.Execute(); err != nil {
log.Fatalln("Error:", err)
}
}
// installBinary copies the running executable into installPath, keeping its
// file name and replacing any existing file there. If the current user lacks
// permission to remove or write the destination, both steps are retried with
// sudo. It returns an error on Windows, where installation is unsupported.
func installBinary(installPath string) error {
	if runtime.GOOS == "windows" {
		return errors.New("the install command is not supported on Windows")
	}
	exePath, err := os.Executable()
	if err != nil {
		return err
	}
	// Windows was rejected above, so slash-separated path.Join is correct here.
	dst := path.Join(installPath, filepath.Base(exePath))

	// Installing onto itself would delete the binary before it is copied.
	if srcInfo, err := os.Stat(exePath); err == nil {
		if dstInfo, err := os.Stat(dst); err == nil && os.SameFile(srcInfo, dstInfo) {
			log.Printf("The binary is already installed at %s", dst)
			return nil
		}
	}

	// Open the source before touching dst so the copy never depends on a path
	// that the removal below could have unlinked.
	srcFile, err := os.Open(exePath)
	if err != nil {
		return err
	}
	defer srcFile.Close()

	needsSudo := false
	if err := os.Remove(dst); err != nil {
		switch {
		case os.IsNotExist(err):
			// Nothing installed yet.
		case os.IsPermission(err):
			needsSudo = true
		default:
			return err
		}
	}

	if !needsSudo {
		if err := copyFile(dst, srcFile); err != nil {
			if !os.IsPermission(err) {
				return err
			}
			needsSudo = true
		}
	}

	if needsSudo {
		fmt.Printf("Require sudo privileges to complete installation at: %s\n", installPath)
		// Run each step as its own process with explicit argv instead of
		// `sh -c`, so paths containing spaces or shell metacharacters are passed
		// verbatim and cannot inject commands.
		steps := [][]string{
			{"rm", "-f", "--", dst},
			{"cp", "--", exePath, dst},
		}
		for _, step := range steps {
			cmd := exec.Command("sudo", step...)
			// Attach the terminal so sudo can prompt and its errors are visible.
			cmd.Stdin, cmd.Stdout, cmd.Stderr = os.Stdin, os.Stdout, os.Stderr
			if err := cmd.Run(); err != nil {
				return fmt.Errorf("sudo %s: %w", step[0], err)
			}
		}
	}

	log.Printf("The binary has been successfully installed to %s", dst)
	return nil
}
// copyFile copies everything readable from src into the file at dst. The
// file is created with permission 0755 (before umask) if it does not exist,
// or truncated if it does. dst is closed before returning, so a failed final
// write-back is reported instead of being lost in a deferred Close.
//
// src accepts any io.Reader; callers passing an *os.File keep working.
func copyFile(dst string, src io.Reader) error {
	dstFile, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0755)
	if err != nil {
		return err
	}
	if _, err := io.Copy(dstFile, src); err != nil {
		dstFile.Close() // the copy error is the one worth reporting
		return err
	}
	return dstFile.Close()
}