Skip to content

Commit

Permalink
support download resume
Browse files Browse the repository at this point in the history
  • Loading branch information
Cyberhan123 committed Jan 9, 2024
1 parent 3bf181c commit 549119e
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 29 deletions.
5 changes: 3 additions & 2 deletions .idea/hf-hub.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

73 changes: 48 additions & 25 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,12 @@ const (
)

// Cache describes an on-disk Hugging Face hub cache directory.
type Cache struct {
	// path is the root directory of the cache; temporary downloads are
	// placed in its "tmp" subdirectory.
	path string
	// resume selects stable temporary file names so that an interrupted
	// download can be continued instead of restarted.
	resume bool
}

func NewCache(path string) *Cache {
//if !filepath.IsAbs(path) {
// path, err := filepath.Abs(path)
// if err != nil {
// return nil, err
// }
// return &Cache{path: path}, nil
//}
return &Cache{path: path}
func NewCache(path string, resume bool) *Cache {
return &Cache{path: path, resume: resume}
}

func DefaultCache() (*Cache, error) {
Expand All @@ -54,7 +48,7 @@ func DefaultCache() (*Cache, error) {
homePath = filepath.Join(homePath, ".cache", "huggingface")
}
cachePath := filepath.Join(homePath, "hub")
return NewCache(cachePath), nil
return NewCache(cachePath, true), nil
}

func (c *Cache) Path() string {
Expand All @@ -67,9 +61,8 @@ func (c *Cache) TokenPath() string {

func (c *Cache) Token() (string, error) {
tokenPath := c.TokenPath()

if _, err := os.Stat(tokenPath); os.IsNotExist(err) {
log.Println("Token file not found")
log.Println("auth token file not found")
return "", nil
}

Expand Down Expand Up @@ -100,18 +93,28 @@ func (c *Cache) Space(modelId string) *CacheRepo {
return c.Repo(NewRepo(modelId, Space))
}

func (c *Cache) TempPath() (string, error) {
func (c *Cache) TempPath(filename string) (string, error) {
path := filepath.Join(c.path, "tmp")
err := os.MkdirAll(path, os.ModePerm)
if err != nil {
return "", err
}
path = filepath.Join(path, randStr(7))

if len(filename) > 0 {
if c.resume {
path = filepath.Join(path, filename+".income")
} else {
path = filepath.Join(path, filename+randStr(7))
}
} else {
path = filepath.Join(path, randStr(7))
}

return path, nil
}

// Clone returns a new Cache with the same root path and resume setting as c.
func (c *Cache) Clone() *Cache {
	return NewCache(c.path, c.resume)
}

Expand Down Expand Up @@ -269,6 +272,7 @@ type ApiBuilder struct {
parallelFailures uint64
maxRetries uint64
progress bool
headers http.Header
}

func NewApiBuilder() (*ApiBuilder, error) {
Expand All @@ -292,14 +296,12 @@ func (b *ApiBuilder) FromCache(cache *Cache) (*ApiBuilder, error) {
}

token, err := cache.Token()

if err != nil {
return nil, err
}

return &ApiBuilder{
endpoint: "https://huggingface.co",
//"{endpoint}/{repo_id}/resolve/{revision}/{filename}"
endpoint: "https://huggingface.co",
urlTemplate: "{{.Endpoint}}/{{.RepoId}}/resolve/{{.Revision}}/{{.Filename}}",
cache: cache,
token: token,
Expand All @@ -318,7 +320,13 @@ func (b *ApiBuilder) WithProgress(progress bool) *ApiBuilder {
}

// WithCacheDir points the builder at a cache rooted at cacheDir and returns
// the builder for chaining.
//
// The resume setting of the builder's current cache is carried over to the
// new one. If the builder has no cache yet, resume defaults to false rather
// than dereferencing a nil cache.
func (b *ApiBuilder) WithCacheDir(cacheDir string) *ApiBuilder {
	resume := false
	if b.cache != nil {
		resume = b.cache.resume
	}
	b.cache = NewCache(cacheDir, resume)
	return b
}

// WithResume enables or disables download resumption and returns the
// builder for chaining.
//
// The builder's cache is replaced by a new Cache with the same path and the
// given resume setting, so a cache shared with other users is not mutated.
// If the builder has no cache yet, the new cache has an empty path rather
// than dereferencing a nil cache.
func (b *ApiBuilder) WithResume(resume bool) *ApiBuilder {
	path := ""
	if b.cache != nil {
		path = b.cache.path
	}
	b.cache = NewCache(path, resume)
	return b
}
Expand Down Expand Up @@ -391,6 +399,7 @@ type Api struct {
client *http.Client
noCDNRedirectClient *http.Client
progress bool
meta *Metadata
}

func NewApi() (*Api, error) {
Expand Down Expand Up @@ -463,19 +472,21 @@ func (a *Api) metadata(url string) (*Metadata, error) {
return nil, err
}

return &Metadata{
a.meta = &Metadata{
commitHash: commitHash,
etag: etag,
size: size,
}, nil
}
return a.meta, nil
}

func (a *Api) downloadTempFile(url string, progressbar *progressbar.ProgressBar) (string, error) {
filename, err := a.cache.TempPath()
filename, err := a.cache.TempPath(a.meta.etag)
if err != nil {
return "", err
}

file, err := os.Create(filename)
file, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666)
if err != nil {
return "", err
}
Expand All @@ -487,6 +498,15 @@ func (a *Api) downloadTempFile(url string, progressbar *progressbar.ProgressBar)
}

req.Header = a.headers.Clone()

stat, _ := file.Stat()
if stat.Size() > 0 {
if a.meta.size > uint64(stat.Size()) {
progressbar.Set64(stat.Size())
req.Header.Add("Range", fmt.Sprintf("bytes=%d-", stat.Size()))
}
}

res, err := a.client.Do(req)
if err != nil {
return "", err
Expand All @@ -505,8 +525,9 @@ func (a *Api) downloadTempFile(url string, progressbar *progressbar.ProgressBar)
return "", err
}

return filename, nil
return file.Name(), nil
}

// Repo builds an ApiRepo for rep, backed by a clone of this Api rather than
// the receiver itself.
func (a *Api) Repo(rep *Repo) *ApiRepo {
	cloned := a.Clone()
	return NewApiRepo(cloned, rep)
}
Expand Down Expand Up @@ -614,6 +635,8 @@ func (r *ApiRepo) Download(filename string) (string, error) {
int64(metadata.size),
progressbar.OptionSetDescription(message),
progressbar.OptionUseANSICodes(useANSICodes),
progressbar.OptionSetPredictTime(true),
progressbar.OptionShowBytes(true),
)
}

Expand Down
2 changes: 1 addition & 1 deletion api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func TestApi_Dataset(t *testing.T) {
return
}

assert(t, sha256, "59ce09415ad8aa45a9e34f88cec2548aeb9de9a73fcda9f6b33a86a065f32b90")
assert(t, sha256, "abdfc9f83b1103b502924072460d4c92f277c9b49c313cef3e48cfcf7428e125")
}

func TestApi_Model(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion python/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def main():
filename="config.json",
revision="main",
cache_dir="../tmp",
endpoint="https://hf-mirror.com",
resume_download=True,
)


Expand Down

0 comments on commit 549119e

Please sign in to comment.