mirror of
				https://github.com/go-gitea/gitea.git
				synced 2025-10-29 10:57:44 +09:00 
			
		
		
		
	
		
			
				
	
	
		
			357 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			357 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright 2016 The Xorm Authors. All rights reserved.
 | |
| // Use of this source code is governed by a BSD-style
 | |
| // license that can be found in the LICENSE file.
 | |
| 
 | |
| package xorm
 | |
| 
 | |
| import (
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"reflect"
 | |
| 	"strconv"
 | |
| 	"strings"
 | |
| 
 | |
| 	"github.com/go-xorm/builder"
 | |
| 	"github.com/go-xorm/core"
 | |
| )
 | |
| 
 | |
| func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error {
 | |
| 	if session.Statement.RefTable == nil ||
 | |
| 		session.Tx != nil {
 | |
| 		return ErrCacheFailed
 | |
| 	}
 | |
| 
 | |
| 	oldhead, newsql := session.Statement.convertUpdateSQL(sqlStr)
 | |
| 	if newsql == "" {
 | |
| 		return ErrCacheFailed
 | |
| 	}
 | |
| 	for _, filter := range session.Engine.dialect.Filters() {
 | |
| 		newsql = filter.Do(newsql, session.Engine.dialect, session.Statement.RefTable)
 | |
| 	}
 | |
| 	session.Engine.logger.Debug("[cacheUpdate] new sql", oldhead, newsql)
 | |
| 
 | |
| 	var nStart int
 | |
| 	if len(args) > 0 {
 | |
| 		if strings.Index(sqlStr, "?") > -1 {
 | |
| 			nStart = strings.Count(oldhead, "?")
 | |
| 		} else {
 | |
| 			// only for pq, TODO: if any other databse?
 | |
| 			nStart = strings.Count(oldhead, "$")
 | |
| 		}
 | |
| 	}
 | |
| 	table := session.Statement.RefTable
 | |
| 	cacher := session.Engine.getCacher2(table)
 | |
| 	tableName := session.Statement.TableName()
 | |
| 	session.Engine.logger.Debug("[cacheUpdate] get cache sql", newsql, args[nStart:])
 | |
| 	ids, err := core.GetCacheSql(cacher, tableName, newsql, args[nStart:])
 | |
| 	if err != nil {
 | |
| 		rows, err := session.DB().Query(newsql, args[nStart:]...)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 		defer rows.Close()
 | |
| 
 | |
| 		ids = make([]core.PK, 0)
 | |
| 		for rows.Next() {
 | |
| 			var res = make([]string, len(table.PrimaryKeys))
 | |
| 			err = rows.ScanSlice(&res)
 | |
| 			if err != nil {
 | |
| 				return err
 | |
| 			}
 | |
| 			var pk core.PK = make([]interface{}, len(table.PrimaryKeys))
 | |
| 			for i, col := range table.PKColumns() {
 | |
| 				if col.SQLType.IsNumeric() {
 | |
| 					n, err := strconv.ParseInt(res[i], 10, 64)
 | |
| 					if err != nil {
 | |
| 						return err
 | |
| 					}
 | |
| 					pk[i] = n
 | |
| 				} else if col.SQLType.IsText() {
 | |
| 					pk[i] = res[i]
 | |
| 				} else {
 | |
| 					return errors.New("not supported")
 | |
| 				}
 | |
| 			}
 | |
| 
 | |
| 			ids = append(ids, pk)
 | |
| 		}
 | |
| 		session.Engine.logger.Debug("[cacheUpdate] find updated id", ids)
 | |
| 	} /*else {
 | |
| 	    session.Engine.LogDebug("[xorm:cacheUpdate] del cached sql:", tableName, newsql, args)
 | |
| 	    cacher.DelIds(tableName, genSqlKey(newsql, args))
 | |
| 	}*/
 | |
| 
 | |
| 	for _, id := range ids {
 | |
| 		sid, err := id.ToString()
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 		if bean := cacher.GetBean(tableName, sid); bean != nil {
 | |
| 			sqls := splitNNoCase(sqlStr, "where", 2)
 | |
| 			if len(sqls) == 0 || len(sqls) > 2 {
 | |
| 				return ErrCacheFailed
 | |
| 			}
 | |
| 
 | |
| 			sqls = splitNNoCase(sqls[0], "set", 2)
 | |
| 			if len(sqls) != 2 {
 | |
| 				return ErrCacheFailed
 | |
| 			}
 | |
| 			kvs := strings.Split(strings.TrimSpace(sqls[1]), ",")
 | |
| 			for idx, kv := range kvs {
 | |
| 				sps := strings.SplitN(kv, "=", 2)
 | |
| 				sps2 := strings.Split(sps[0], ".")
 | |
| 				colName := sps2[len(sps2)-1]
 | |
| 				if strings.Contains(colName, "`") {
 | |
| 					colName = strings.TrimSpace(strings.Replace(colName, "`", "", -1))
 | |
| 				} else if strings.Contains(colName, session.Engine.QuoteStr()) {
 | |
| 					colName = strings.TrimSpace(strings.Replace(colName, session.Engine.QuoteStr(), "", -1))
 | |
| 				} else {
 | |
| 					session.Engine.logger.Debug("[cacheUpdate] cannot find column", tableName, colName)
 | |
| 					return ErrCacheFailed
 | |
| 				}
 | |
| 
 | |
| 				if col := table.GetColumn(colName); col != nil {
 | |
| 					fieldValue, err := col.ValueOf(bean)
 | |
| 					if err != nil {
 | |
| 						session.Engine.logger.Error(err)
 | |
| 					} else {
 | |
| 						session.Engine.logger.Debug("[cacheUpdate] set bean field", bean, colName, fieldValue.Interface())
 | |
| 						if col.IsVersion && session.Statement.checkVersion {
 | |
| 							fieldValue.SetInt(fieldValue.Int() + 1)
 | |
| 						} else {
 | |
| 							fieldValue.Set(reflect.ValueOf(args[idx]))
 | |
| 						}
 | |
| 					}
 | |
| 				} else {
 | |
| 					session.Engine.logger.Errorf("[cacheUpdate] ERROR: column %v is not table %v's",
 | |
| 						colName, table.Name)
 | |
| 				}
 | |
| 			}
 | |
| 
 | |
| 			session.Engine.logger.Debug("[cacheUpdate] update cache", tableName, id, bean)
 | |
| 			cacher.PutBean(tableName, sid, bean)
 | |
| 		}
 | |
| 	}
 | |
| 	session.Engine.logger.Debug("[cacheUpdate] clear cached table sql:", tableName)
 | |
| 	cacher.ClearIds(tableName)
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // Update records, bean's non-empty fields are updated contents,
 | |
| // condiBean' non-empty filds are conditions
 | |
| // CAUTION:
 | |
| //        1.bool will defaultly be updated content nor conditions
 | |
| //         You should call UseBool if you have bool to use.
 | |
| //        2.float32 & float64 may be not inexact as conditions
 | |
| func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int64, error) {
 | |
| 	defer session.resetStatement()
 | |
| 	if session.IsAutoClose {
 | |
| 		defer session.Close()
 | |
| 	}
 | |
| 
 | |
| 	v := rValue(bean)
 | |
| 	t := v.Type()
 | |
| 
 | |
| 	var colNames []string
 | |
| 	var args []interface{}
 | |
| 
 | |
| 	// handle before update processors
 | |
| 	for _, closure := range session.beforeClosures {
 | |
| 		closure(bean)
 | |
| 	}
 | |
| 	cleanupProcessorsClosures(&session.beforeClosures) // cleanup after used
 | |
| 	if processor, ok := interface{}(bean).(BeforeUpdateProcessor); ok {
 | |
| 		processor.BeforeUpdate()
 | |
| 	}
 | |
| 	// --
 | |
| 
 | |
| 	var err error
 | |
| 	var isMap = t.Kind() == reflect.Map
 | |
| 	var isStruct = t.Kind() == reflect.Struct
 | |
| 	if isStruct {
 | |
| 		session.Statement.setRefValue(v)
 | |
| 
 | |
| 		if len(session.Statement.TableName()) <= 0 {
 | |
| 			return 0, ErrTableNotFound
 | |
| 		}
 | |
| 
 | |
| 		if session.Statement.ColumnStr == "" {
 | |
| 			colNames, args = buildUpdates(session.Engine, session.Statement.RefTable, bean, false, false,
 | |
| 				false, false, session.Statement.allUseBool, session.Statement.useAllCols,
 | |
| 				session.Statement.mustColumnMap, session.Statement.nullableMap,
 | |
| 				session.Statement.columnMap, true, session.Statement.unscoped)
 | |
| 		} else {
 | |
| 			colNames, args, err = genCols(session.Statement.RefTable, session, bean, true, true)
 | |
| 			if err != nil {
 | |
| 				return 0, err
 | |
| 			}
 | |
| 		}
 | |
| 	} else if isMap {
 | |
| 		colNames = make([]string, 0)
 | |
| 		args = make([]interface{}, 0)
 | |
| 		bValue := reflect.Indirect(reflect.ValueOf(bean))
 | |
| 
 | |
| 		for _, v := range bValue.MapKeys() {
 | |
| 			colNames = append(colNames, session.Engine.Quote(v.String())+" = ?")
 | |
| 			args = append(args, bValue.MapIndex(v).Interface())
 | |
| 		}
 | |
| 	} else {
 | |
| 		return 0, ErrParamsType
 | |
| 	}
 | |
| 
 | |
| 	table := session.Statement.RefTable
 | |
| 
 | |
| 	if session.Statement.UseAutoTime && table != nil && table.Updated != "" {
 | |
| 		colNames = append(colNames, session.Engine.Quote(table.Updated)+" = ?")
 | |
| 		col := table.UpdatedColumn()
 | |
| 		val, t := session.Engine.NowTime2(col.SQLType.Name)
 | |
| 		args = append(args, val)
 | |
| 
 | |
| 		var colName = col.Name
 | |
| 		if isStruct {
 | |
| 			session.afterClosures = append(session.afterClosures, func(bean interface{}) {
 | |
| 				col := table.GetColumn(colName)
 | |
| 				setColumnTime(bean, col, t)
 | |
| 			})
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	//for update action to like "column = column + ?"
 | |
| 	incColumns := session.Statement.getInc()
 | |
| 	for _, v := range incColumns {
 | |
| 		colNames = append(colNames, session.Engine.Quote(v.colName)+" = "+session.Engine.Quote(v.colName)+" + ?")
 | |
| 		args = append(args, v.arg)
 | |
| 	}
 | |
| 	//for update action to like "column = column - ?"
 | |
| 	decColumns := session.Statement.getDec()
 | |
| 	for _, v := range decColumns {
 | |
| 		colNames = append(colNames, session.Engine.Quote(v.colName)+" = "+session.Engine.Quote(v.colName)+" - ?")
 | |
| 		args = append(args, v.arg)
 | |
| 	}
 | |
| 	//for update action to like "column = expression"
 | |
| 	exprColumns := session.Statement.getExpr()
 | |
| 	for _, v := range exprColumns {
 | |
| 		colNames = append(colNames, session.Engine.Quote(v.colName)+" = "+v.expr)
 | |
| 	}
 | |
| 
 | |
| 	session.Statement.processIDParam()
 | |
| 
 | |
| 	var autoCond builder.Cond
 | |
| 	if !session.Statement.noAutoCondition && len(condiBean) > 0 {
 | |
| 		var err error
 | |
| 		autoCond, err = session.Statement.buildConds(session.Statement.RefTable, condiBean[0], true, true, false, true, false)
 | |
| 		if err != nil {
 | |
| 			return 0, err
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	st := session.Statement
 | |
| 	defer session.resetStatement()
 | |
| 
 | |
| 	var sqlStr string
 | |
| 	var condArgs []interface{}
 | |
| 	var condSQL string
 | |
| 	cond := session.Statement.cond.And(autoCond)
 | |
| 
 | |
| 	var doIncVer = (table != nil && table.Version != "" && session.Statement.checkVersion)
 | |
| 	var verValue *reflect.Value
 | |
| 	if doIncVer {
 | |
| 		verValue, err = table.VersionColumn().ValueOf(bean)
 | |
| 		if err != nil {
 | |
| 			return 0, err
 | |
| 		}
 | |
| 
 | |
| 		cond = cond.And(builder.Eq{session.Engine.Quote(table.Version): verValue.Interface()})
 | |
| 		colNames = append(colNames, session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1")
 | |
| 	}
 | |
| 
 | |
| 	condSQL, condArgs, _ = builder.ToSQL(cond)
 | |
| 	if len(condSQL) > 0 {
 | |
| 		condSQL = "WHERE " + condSQL
 | |
| 	}
 | |
| 
 | |
| 	if st.OrderStr != "" {
 | |
| 		condSQL = condSQL + fmt.Sprintf(" ORDER BY %v", st.OrderStr)
 | |
| 	}
 | |
| 
 | |
| 	// TODO: Oracle support needed
 | |
| 	var top string
 | |
| 	if st.LimitN > 0 {
 | |
| 		if st.Engine.dialect.DBType() == core.MYSQL {
 | |
| 			condSQL = condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
 | |
| 		} else if st.Engine.dialect.DBType() == core.SQLITE {
 | |
| 			tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
 | |
| 			cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
 | |
| 				session.Engine.Quote(session.Statement.TableName()), tempCondSQL), condArgs...))
 | |
| 			condSQL, condArgs, _ = builder.ToSQL(cond)
 | |
| 			if len(condSQL) > 0 {
 | |
| 				condSQL = "WHERE " + condSQL
 | |
| 			}
 | |
| 		} else if st.Engine.dialect.DBType() == core.POSTGRES {
 | |
| 			tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
 | |
| 			cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
 | |
| 				session.Engine.Quote(session.Statement.TableName()), tempCondSQL), condArgs...))
 | |
| 			condSQL, condArgs, _ = builder.ToSQL(cond)
 | |
| 			if len(condSQL) > 0 {
 | |
| 				condSQL = "WHERE " + condSQL
 | |
| 			}
 | |
| 		} else if st.Engine.dialect.DBType() == core.MSSQL {
 | |
| 			top = fmt.Sprintf("top (%d) ", st.LimitN)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	sqlStr = fmt.Sprintf("UPDATE %v%v SET %v %v",
 | |
| 		top,
 | |
| 		session.Engine.Quote(session.Statement.TableName()),
 | |
| 		strings.Join(colNames, ", "),
 | |
| 		condSQL)
 | |
| 
 | |
| 	res, err := session.exec(sqlStr, append(args, condArgs...)...)
 | |
| 	if err != nil {
 | |
| 		return 0, err
 | |
| 	} else if doIncVer {
 | |
| 		if verValue != nil && verValue.IsValid() && verValue.CanSet() {
 | |
| 			verValue.SetInt(verValue.Int() + 1)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if table != nil {
 | |
| 		if cacher := session.Engine.getCacher2(table); cacher != nil && session.Statement.UseCache {
 | |
| 			cacher.ClearIds(session.Statement.TableName())
 | |
| 			cacher.ClearBeans(session.Statement.TableName())
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// handle after update processors
 | |
| 	if session.IsAutoCommit {
 | |
| 		for _, closure := range session.afterClosures {
 | |
| 			closure(bean)
 | |
| 		}
 | |
| 		if processor, ok := interface{}(bean).(AfterUpdateProcessor); ok {
 | |
| 			session.Engine.logger.Debug("[event]", session.Statement.TableName(), " has after update processor")
 | |
| 			processor.AfterUpdate()
 | |
| 		}
 | |
| 	} else {
 | |
| 		lenAfterClosures := len(session.afterClosures)
 | |
| 		if lenAfterClosures > 0 {
 | |
| 			if value, has := session.afterUpdateBeans[bean]; has && value != nil {
 | |
| 				*value = append(*value, session.afterClosures...)
 | |
| 			} else {
 | |
| 				afterClosures := make([]func(interface{}), lenAfterClosures)
 | |
| 				copy(afterClosures, session.afterClosures)
 | |
| 				// FIXME: if bean is a map type, it will panic because map cannot be as map key
 | |
| 				session.afterUpdateBeans[bean] = &afterClosures
 | |
| 			}
 | |
| 
 | |
| 		} else {
 | |
| 			if _, ok := interface{}(bean).(AfterUpdateProcessor); ok {
 | |
| 				session.afterUpdateBeans[bean] = nil
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 	cleanupProcessorsClosures(&session.afterClosures) // cleanup after used
 | |
| 	// --
 | |
| 
 | |
| 	return res.RowsAffected()
 | |
| }
 |