Skip to content

Commit

Permalink
support transaction tag filter type
Browse files Browse the repository at this point in the history
  • Loading branch information
mayswind committed Dec 7, 2024
1 parent 5003f8b commit dd35a85
Show file tree
Hide file tree
Showing 13 changed files with 271 additions and 91 deletions.
8 changes: 4 additions & 4 deletions pkg/api/transactions.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func (a *TransactionsApi) TransactionCountHandler(c *core.WebContext) (any, *err
}
}

totalCount, err := a.transactions.GetTransactionCount(c, uid, transactionCountReq.MaxTime, transactionCountReq.MinTime, transactionCountReq.Type, allCategoryIds, allAccountIds, allTagIds, noTags, transactionCountReq.AmountFilter, transactionCountReq.Keyword)
totalCount, err := a.transactions.GetTransactionCount(c, uid, transactionCountReq.MaxTime, transactionCountReq.MinTime, transactionCountReq.Type, allCategoryIds, allAccountIds, allTagIds, noTags, transactionCountReq.TagFilterType, transactionCountReq.AmountFilter, transactionCountReq.Keyword)

if err != nil {
log.Errorf(c, "[transactions.TransactionCountHandler] failed to get transaction count for user \"uid:%d\", because %s", uid, err.Error())
Expand Down Expand Up @@ -160,15 +160,15 @@ func (a *TransactionsApi) TransactionListHandler(c *core.WebContext) (any, *errs
var totalCount int64

if transactionListReq.WithCount {
totalCount, err = a.transactions.GetTransactionCount(c, uid, transactionListReq.MaxTime, transactionListReq.MinTime, transactionListReq.Type, allCategoryIds, allAccountIds, allTagIds, noTags, transactionListReq.AmountFilter, transactionListReq.Keyword)
totalCount, err = a.transactions.GetTransactionCount(c, uid, transactionListReq.MaxTime, transactionListReq.MinTime, transactionListReq.Type, allCategoryIds, allAccountIds, allTagIds, noTags, transactionListReq.TagFilterType, transactionListReq.AmountFilter, transactionListReq.Keyword)

if err != nil {
log.Errorf(c, "[transactions.TransactionListHandler] failed to get transaction count for user \"uid:%d\", because %s", uid, err.Error())
return nil, errs.Or(err, errs.ErrOperationFailed)
}
}

transactions, err := a.transactions.GetTransactionsByMaxTime(c, uid, transactionListReq.MaxTime, transactionListReq.MinTime, transactionListReq.Type, allCategoryIds, allAccountIds, allTagIds, noTags, transactionListReq.AmountFilter, transactionListReq.Keyword, transactionListReq.Page, transactionListReq.Count, true, true)
transactions, err := a.transactions.GetTransactionsByMaxTime(c, uid, transactionListReq.MaxTime, transactionListReq.MinTime, transactionListReq.Type, allCategoryIds, allAccountIds, allTagIds, noTags, transactionListReq.TagFilterType, transactionListReq.AmountFilter, transactionListReq.Keyword, transactionListReq.Page, transactionListReq.Count, true, true)

if err != nil {
log.Errorf(c, "[transactions.TransactionListHandler] failed to get transactions earlier than \"%d\" for user \"uid:%d\", because %s", transactionListReq.MaxTime, uid, err.Error())
Expand Down Expand Up @@ -260,7 +260,7 @@ func (a *TransactionsApi) TransactionMonthListHandler(c *core.WebContext) (any,
}
}

transactions, err := a.transactions.GetTransactionsInMonthByPage(c, uid, transactionListReq.Year, transactionListReq.Month, transactionListReq.Type, allCategoryIds, allAccountIds, allTagIds, noTags, transactionListReq.AmountFilter, transactionListReq.Keyword)
transactions, err := a.transactions.GetTransactionsInMonthByPage(c, uid, transactionListReq.Year, transactionListReq.Month, transactionListReq.Type, allCategoryIds, allAccountIds, allTagIds, noTags, transactionListReq.TagFilterType, transactionListReq.AmountFilter, transactionListReq.Keyword)

if err != nil {
log.Errorf(c, "[transactions.TransactionMonthListHandler] failed to get transactions in month \"%d-%d\" for user \"uid:%d\", because %s", transactionListReq.Year, transactionListReq.Month, uid, err.Error())
Expand Down
84 changes: 49 additions & 35 deletions pkg/models/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,17 @@ func (s TransactionDbType) ToTransactionType() (TransactionType, error) {
}
}

// TransactionTagFilterType represents transaction tag filter type
type TransactionTagFilterType byte

// Transaction tag filter types
const (
TRANSACTION_TAG_FILTER_HAS_ANY TransactionTagFilterType = 0
TRANSACTION_TAG_FILTER_HAS_ALL TransactionTagFilterType = 1
TRANSACTION_TAG_FILTER_NOT_HAS_ANY TransactionTagFilterType = 2
TRANSACTION_TAG_FILTER_NOT_HAS_ALL TransactionTagFilterType = 3
)

// Transaction represents transaction data stored in database
type Transaction struct {
TransactionId int64 `xorm:"PK"`
Expand Down Expand Up @@ -140,49 +151,52 @@ type TransactionImportRequest struct {

// TransactionCountRequest represents transaction count request
type TransactionCountRequest struct {
Type TransactionDbType `form:"type" binding:"min=0,max=4"`
CategoryIds string `form:"category_ids"`
AccountIds string `form:"account_ids"`
TagIds string `form:"tag_ids"`
AmountFilter string `form:"amount_filter" binding:"validAmountFilter"`
Keyword string `form:"keyword"`
MaxTime int64 `form:"max_time" binding:"min=0"`
MinTime int64 `form:"min_time" binding:"min=0"`
Type TransactionDbType `form:"type" binding:"min=0,max=4"`
CategoryIds string `form:"category_ids"`
AccountIds string `form:"account_ids"`
TagIds string `form:"tag_ids"`
TagFilterType TransactionTagFilterType `form:"tag_filter_type" binding:"min=0,max=3"`
AmountFilter string `form:"amount_filter" binding:"validAmountFilter"`
Keyword string `form:"keyword"`
MaxTime int64 `form:"max_time" binding:"min=0"`
MinTime int64 `form:"min_time" binding:"min=0"`
}

// TransactionListByMaxTimeRequest represents all parameters of transaction listing by max time request
type TransactionListByMaxTimeRequest struct {
Type TransactionDbType `form:"type" binding:"min=0,max=4"`
CategoryIds string `form:"category_ids"`
AccountIds string `form:"account_ids"`
TagIds string `form:"tag_ids"`
AmountFilter string `form:"amount_filter" binding:"validAmountFilter"`
Keyword string `form:"keyword"`
MaxTime int64 `form:"max_time" binding:"min=0"`
MinTime int64 `form:"min_time" binding:"min=0"`
Page int32 `form:"page" binding:"min=0"`
Count int32 `form:"count" binding:"required,min=1,max=50"`
WithCount bool `form:"with_count"`
WithPictures bool `form:"with_pictures"`
TrimAccount bool `form:"trim_account"`
TrimCategory bool `form:"trim_category"`
TrimTag bool `form:"trim_tag"`
Type TransactionDbType `form:"type" binding:"min=0,max=4"`
CategoryIds string `form:"category_ids"`
AccountIds string `form:"account_ids"`
TagIds string `form:"tag_ids"`
TagFilterType TransactionTagFilterType `form:"tag_filter_type" binding:"min=0,max=3"`
AmountFilter string `form:"amount_filter" binding:"validAmountFilter"`
Keyword string `form:"keyword"`
MaxTime int64 `form:"max_time" binding:"min=0"`
MinTime int64 `form:"min_time" binding:"min=0"`
Page int32 `form:"page" binding:"min=0"`
Count int32 `form:"count" binding:"required,min=1,max=50"`
WithCount bool `form:"with_count"`
WithPictures bool `form:"with_pictures"`
TrimAccount bool `form:"trim_account"`
TrimCategory bool `form:"trim_category"`
TrimTag bool `form:"trim_tag"`
}

// TransactionListInMonthByPageRequest represents all parameters of transaction listing by month request
type TransactionListInMonthByPageRequest struct {
Year int32 `form:"year" binding:"required,min=1"`
Month int32 `form:"month" binding:"required,min=1"`
Type TransactionDbType `form:"type" binding:"min=0,max=4"`
CategoryIds string `form:"category_ids"`
AccountIds string `form:"account_ids"`
TagIds string `form:"tag_ids"`
AmountFilter string `form:"amount_filter" binding:"validAmountFilter"`
Keyword string `form:"keyword"`
WithPictures bool `form:"with_pictures"`
TrimAccount bool `form:"trim_account"`
TrimCategory bool `form:"trim_category"`
TrimTag bool `form:"trim_tag"`
Year int32 `form:"year" binding:"required,min=1"`
Month int32 `form:"month" binding:"required,min=1"`
Type TransactionDbType `form:"type" binding:"min=0,max=4"`
CategoryIds string `form:"category_ids"`
AccountIds string `form:"account_ids"`
TagIds string `form:"tag_ids"`
TagFilterType TransactionTagFilterType `form:"tag_filter_type" binding:"min=0,max=3"`
AmountFilter string `form:"amount_filter" binding:"validAmountFilter"`
Keyword string `form:"keyword"`
WithPictures bool `form:"with_pictures"`
TrimAccount bool `form:"trim_account"`
TrimCategory bool `form:"trim_category"`
TrimTag bool `form:"trim_tag"`
}

// TransactionStatisticRequest represents all parameters of transaction statistic request
Expand Down
84 changes: 36 additions & 48 deletions pkg/services/transactions.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,11 @@ func (s *TransactionService) GetAllTransactions(c core.Context, uid int64, pageC

// GetAllTransactionsByMaxTime returns all transactions before given time
func (s *TransactionService) GetAllTransactionsByMaxTime(c core.Context, uid int64, maxTransactionTime int64, count int32, noDuplicated bool) ([]*models.Transaction, error) {
return s.GetTransactionsByMaxTime(c, uid, maxTransactionTime, 0, 0, nil, nil, nil, false, "", "", 1, count, false, noDuplicated)
return s.GetTransactionsByMaxTime(c, uid, maxTransactionTime, 0, 0, nil, nil, nil, false, models.TRANSACTION_TAG_FILTER_HAS_ANY, "", "", 1, count, false, noDuplicated)
}

// GetTransactionsByMaxTime returns transactions before given time
func (s *TransactionService) GetTransactionsByMaxTime(c core.Context, uid int64, maxTransactionTime int64, minTransactionTime int64, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, tagIds []int64, noTags bool, amountFilter string, keyword string, page int32, count int32, needOneMoreItem bool, noDuplicated bool) ([]*models.Transaction, error) {
func (s *TransactionService) GetTransactionsByMaxTime(c core.Context, uid int64, maxTransactionTime int64, minTransactionTime int64, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, tagIds []int64, noTags bool, tagFilterType models.TransactionTagFilterType, amountFilter string, keyword string, page int32, count int32, needOneMoreItem bool, noDuplicated bool) ([]*models.Transaction, error) {
if uid <= 0 {
return nil, errs.ErrUserIdInvalid
}
Expand All @@ -103,22 +103,17 @@ func (s *TransactionService) GetTransactionsByMaxTime(c core.Context, uid int64,
actualCount++
}

condition, conditionParams := s.getTransactionQueryCondition(uid, maxTransactionTime, minTransactionTime, transactionType, categoryIds, accountIds, tagIds, amountFilter, keyword, noDuplicated)
condition, conditionParams := s.buildTransactionQueryCondition(uid, maxTransactionTime, minTransactionTime, transactionType, categoryIds, accountIds, tagIds, amountFilter, keyword, noDuplicated)
sess := s.UserDataDB(uid).NewSession(c).Where(condition, conditionParams...)

if len(tagIds) > 0 {
sess.In("transaction_id", s.getTransactionQueryByTagIdsCondition(uid, maxTransactionTime, minTransactionTime, tagIds))
} else if noTags {
sess.NotIn("transaction_id", s.getTransactionQueryByAllTagIdsCondition(uid, maxTransactionTime, minTransactionTime))
}
sess = s.appendFilterTagIdsConditionToQuery(sess, uid, maxTransactionTime, minTransactionTime, tagIds, noTags, tagFilterType)

err = sess.Limit(int(actualCount), int(count*(page-1))).OrderBy("transaction_time desc").Find(&transactions)

return transactions, err
}

// GetTransactionsInMonthByPage returns all transactions in given year and month
func (s *TransactionService) GetTransactionsInMonthByPage(c core.Context, uid int64, year int32, month int32, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, tagIds []int64, noTags bool, amountFilter string, keyword string) ([]*models.Transaction, error) {
func (s *TransactionService) GetTransactionsInMonthByPage(c core.Context, uid int64, year int32, month int32, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, tagIds []int64, noTags bool, tagFilterType models.TransactionTagFilterType, amountFilter string, keyword string) ([]*models.Transaction, error) {
if uid <= 0 {
return nil, errs.ErrUserIdInvalid
}
Expand All @@ -131,14 +126,9 @@ func (s *TransactionService) GetTransactionsInMonthByPage(c core.Context, uid in

var transactions []*models.Transaction

condition, conditionParams := s.getTransactionQueryCondition(uid, maxTransactionTime, minTransactionTime, transactionType, categoryIds, accountIds, tagIds, amountFilter, keyword, true)
condition, conditionParams := s.buildTransactionQueryCondition(uid, maxTransactionTime, minTransactionTime, transactionType, categoryIds, accountIds, tagIds, amountFilter, keyword, true)
sess := s.UserDataDB(uid).NewSession(c).Where(condition, conditionParams...)

if len(tagIds) > 0 {
sess.In("transaction_id", s.getTransactionQueryByTagIdsCondition(uid, maxTransactionTime, minTransactionTime, tagIds))
} else if noTags {
sess.NotIn("transaction_id", s.getTransactionQueryByAllTagIdsCondition(uid, maxTransactionTime, minTransactionTime))
}
sess = s.appendFilterTagIdsConditionToQuery(sess, uid, maxTransactionTime, minTransactionTime, tagIds, noTags, tagFilterType)

err = sess.OrderBy("transaction_time desc").Find(&transactions)

Expand Down Expand Up @@ -181,23 +171,18 @@ func (s *TransactionService) GetTransactionByTransactionId(c core.Context, uid i

// GetAllTransactionCount returns total count of transactions
func (s *TransactionService) GetAllTransactionCount(c core.Context, uid int64) (int64, error) {
return s.GetTransactionCount(c, uid, 0, 0, 0, nil, nil, nil, false, "", "")
return s.GetTransactionCount(c, uid, 0, 0, 0, nil, nil, nil, false, models.TRANSACTION_TAG_FILTER_HAS_ANY, "", "")
}

// GetTransactionCount returns count of transactions
func (s *TransactionService) GetTransactionCount(c core.Context, uid int64, maxTransactionTime int64, minTransactionTime int64, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, tagIds []int64, noTags bool, amountFilter string, keyword string) (int64, error) {
func (s *TransactionService) GetTransactionCount(c core.Context, uid int64, maxTransactionTime int64, minTransactionTime int64, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, tagIds []int64, noTags bool, tagFilterType models.TransactionTagFilterType, amountFilter string, keyword string) (int64, error) {
if uid <= 0 {
return 0, errs.ErrUserIdInvalid
}

condition, conditionParams := s.getTransactionQueryCondition(uid, maxTransactionTime, minTransactionTime, transactionType, categoryIds, accountIds, tagIds, amountFilter, keyword, true)
condition, conditionParams := s.buildTransactionQueryCondition(uid, maxTransactionTime, minTransactionTime, transactionType, categoryIds, accountIds, tagIds, amountFilter, keyword, true)
sess := s.UserDataDB(uid).NewSession(c).Where(condition, conditionParams...)

if len(tagIds) > 0 {
sess.In("transaction_id", s.getTransactionQueryByTagIdsCondition(uid, maxTransactionTime, minTransactionTime, tagIds))
} else if noTags {
sess.NotIn("transaction_id", s.getTransactionQueryByAllTagIdsCondition(uid, maxTransactionTime, minTransactionTime))
}
sess = s.appendFilterTagIdsConditionToQuery(sess, uid, maxTransactionTime, minTransactionTime, tagIds, noTags, tagFilterType)

return sess.Count(&models.Transaction{})
}
Expand Down Expand Up @@ -1753,7 +1738,7 @@ func (s *TransactionService) doCreateTransaction(sess *xorm.Session, transaction
return err
}

func (s *TransactionService) getTransactionQueryCondition(uid int64, maxTransactionTime int64, minTransactionTime int64, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, tagIds []int64, amountFilter string, keyword string, noDuplicated bool) (string, []any) {
func (s *TransactionService) buildTransactionQueryCondition(uid int64, maxTransactionTime int64, minTransactionTime int64, transactionType models.TransactionDbType, categoryIds []int64, accountIds []int64, tagIds []int64, amountFilter string, keyword string, noDuplicated bool) (string, []any) {
condition := "uid=? AND deleted=?"
conditionParams := make([]any, 0, 16)
conditionParams = append(conditionParams, uid)
Expand Down Expand Up @@ -1909,38 +1894,41 @@ func (s *TransactionService) getTransactionQueryCondition(uid int64, maxTransact
return condition, conditionParams
}

func (s *TransactionService) getTransactionQueryByTagIdsCondition(uid int64, maxTransactionTime int64, minTransactionTime int64, tagIds []int64) *builder.Builder {
if len(tagIds) > 0 {
condition := builder.And(builder.Eq{"uid": uid}, builder.Eq{"deleted": false})
func (s *TransactionService) appendFilterTagIdsConditionToQuery(sess *xorm.Session, uid int64, maxTransactionTime int64, minTransactionTime int64, tagIds []int64, noTags bool, tagFilterType models.TransactionTagFilterType) *xorm.Session {
subQueryCondition := builder.And(builder.Eq{"uid": uid}, builder.Eq{"deleted": false})

if maxTransactionTime > 0 {
condition = condition.And(builder.Lte{"transaction_time": maxTransactionTime})
}

if minTransactionTime > 0 {
condition = condition.And(builder.Gte{"transaction_time": minTransactionTime})
}
if maxTransactionTime > 0 {
subQueryCondition = subQueryCondition.And(builder.Lte{"transaction_time": maxTransactionTime})
}

condition = condition.And(builder.In("tag_id", tagIds))
if minTransactionTime > 0 {
subQueryCondition = subQueryCondition.And(builder.Gte{"transaction_time": minTransactionTime})
}

return builder.Select("transaction_id").From("transaction_tag_index").Where(condition)
if noTags {
subQuery := builder.Select("transaction_id").From("transaction_tag_index").Where(subQueryCondition)
sess.NotIn("transaction_id", subQuery)
return sess
}

return nil
}
if len(tagIds) < 1 {
return sess
}

func (s *TransactionService) getTransactionQueryByAllTagIdsCondition(uid int64, maxTransactionTime int64, minTransactionTime int64) *builder.Builder {
condition := builder.And(builder.Eq{"uid": uid}, builder.Eq{"deleted": false})
subQueryCondition = subQueryCondition.And(builder.In("tag_id", tagIds))
subQuery := builder.Select("transaction_id").From("transaction_tag_index").Where(subQueryCondition)

if maxTransactionTime > 0 {
condition = condition.And(builder.Lte{"transaction_time": maxTransactionTime})
if tagFilterType == models.TRANSACTION_TAG_FILTER_HAS_ALL || tagFilterType == models.TRANSACTION_TAG_FILTER_NOT_HAS_ALL {
subQuery = subQuery.GroupBy("transaction_id").Having(fmt.Sprintf("COUNT(DISTINCT tag_id) >= %d", len(tagIds)))
}

if minTransactionTime > 0 {
condition = condition.And(builder.Gte{"transaction_time": minTransactionTime})
if tagFilterType == models.TRANSACTION_TAG_FILTER_HAS_ANY || tagFilterType == models.TRANSACTION_TAG_FILTER_HAS_ALL {
sess.In("transaction_id", subQuery)
} else if tagFilterType == models.TRANSACTION_TAG_FILTER_NOT_HAS_ANY || tagFilterType == models.TRANSACTION_TAG_FILTER_NOT_HAS_ALL {
sess.NotIn("transaction_id", subQuery)
}

return builder.Select("transaction_id").From("transaction_tag_index").Where(condition)
return sess
}

func (s *TransactionService) isAccountIdValid(transaction *models.Transaction) error {
Expand Down
Loading

0 comments on commit dd35a85

Please sign in to comment.