drakkan / sftpgo

@@ -6,6 +6,7 @@
Loading
6 6
	"os"
7 7
	"path"
8 8
	"strings"
9 +
	"sync"
9 10
	"sync/atomic"
10 11
	"time"
11 12
@@ -113,7 +114,7 @@
Loading
113 114
	}
114 115
115 116
	baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, p, p, name, common.TransferDownload,
116 -
		0, 0, 0, false, fs)
117 +
		0, 0, 0, 0, false, fs)
117 118
	return newHTTPDFile(baseTransfer, nil, r), nil
118 119
}
119 120
@@ -190,6 +191,7 @@
Loading
190 191
	}
191 192
192 193
	initialSize := int64(0)
194 +
	truncatedSize := int64(0) // bytes truncated and not included in quota
193 195
	if !isNewFile {
194 196
		if vfs.IsLocalOrSFTPFs(fs) {
195 197
			vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
@@ -203,6 +205,7 @@
Loading
203 205
			}
204 206
		} else {
205 207
			initialSize = fileSize
208 +
			truncatedSize = fileSize
206 209
		}
207 210
		if maxWriteSize > 0 {
208 211
			maxWriteSize += fileSize
@@ -212,7 +215,7 @@
Loading
212 215
	vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID())
213 216
214 217
	baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
215 -
		common.TransferUpload, 0, initialSize, maxWriteSize, isNewFile, fs)
218 +
		common.TransferUpload, 0, initialSize, maxWriteSize, truncatedSize, isNewFile, fs)
216 219
	return newHTTPDFile(baseTransfer, w, nil), nil
217 220
}
218 221
@@ -232,15 +235,17 @@
Loading
232 235
233 236
type throttledReader struct {
234 237
	bytesRead     int64
235 -
	id            uint64
238 +
	id            int64
236 239
	limit         int64
237 240
	r             io.ReadCloser
238 241
	abortTransfer int32
239 242
	start         time.Time
240 243
	conn          *Connection
244 +
	mu            sync.Mutex
245 +
	errAbort      error
241 246
}
242 247
243 -
func (t *throttledReader) GetID() uint64 {
248 +
func (t *throttledReader) GetID() int64 {
244 249
	return t.id
245 250
}
246 251
@@ -252,6 +257,14 @@
Loading
252 257
	return atomic.LoadInt64(&t.bytesRead)
253 258
}
254 259
260 +
func (t *throttledReader) GetDownloadedSize() int64 {
261 +
	return 0
262 +
}
263 +
264 +
func (t *throttledReader) GetUploadedSize() int64 {
265 +
	return atomic.LoadInt64(&t.bytesRead)
266 +
}
267 +
255 268
func (t *throttledReader) GetVirtualPath() string {
256 269
	return "**reading request body**"
257 270
}
@@ -260,10 +273,31 @@
Loading
260 273
	return t.start
261 274
}
262 275
263 -
func (t *throttledReader) SignalClose() {
276 +
func (t *throttledReader) GetAbortError() error {
277 +
	t.mu.Lock()
278 +
	defer t.mu.Unlock()
279 +
280 +
	if t.errAbort != nil {
281 +
		return t.errAbort
282 +
	}
283 +
	return common.ErrTransferAborted
284 +
}
285 +
286 +
func (t *throttledReader) SignalClose(err error) {
287 +
	t.mu.Lock()
288 +
	t.errAbort = err
289 +
	t.mu.Unlock()
264 290
	atomic.StoreInt32(&(t.abortTransfer), 1)
265 291
}
266 292
293 +
func (t *throttledReader) GetTruncatedSize() int64 {
294 +
	return 0
295 +
}
296 +
297 +
func (t *throttledReader) GetMaxAllowedSize() int64 {
298 +
	return 0
299 +
}
300 +
267 301
func (t *throttledReader) Truncate(fsPath string, size int64) (int64, error) {
268 302
	return 0, vfs.ErrVfsUnsupported
269 303
}
@@ -278,7 +312,7 @@
Loading
278 312
279 313
func (t *throttledReader) Read(p []byte) (n int, err error) {
280 314
	if atomic.LoadInt32(&t.abortTransfer) == 1 {
281 -
		return 0, errTransferAborted
315 +
		return 0, t.GetAbortError()
282 316
	}
283 317
284 318
	t.conn.UpdateLastActivity()

@@ -1,7 +1,6 @@
Loading
1 1
package httpd
2 2
3 3
import (
4 -
	"errors"
5 4
	"io"
6 5
	"sync/atomic"
7 6
@@ -11,8 +10,6 @@
Loading
11 10
	"github.com/drakkan/sftpgo/v2/vfs"
12 11
)
13 12
14 -
var errTransferAborted = errors.New("transfer aborted")
15 -
16 13
type httpdFile struct {
17 14
	*common.BaseTransfer
18 15
	writer     io.WriteCloser
@@ -42,7 +39,9 @@
Loading
42 39
// Read reads the contents to downloads.
43 40
func (f *httpdFile) Read(p []byte) (n int, err error) {
44 41
	if atomic.LoadInt32(&f.AbortTransfer) == 1 {
45 -
		return 0, errTransferAborted
42 +
		err := f.GetAbortError()
43 +
		f.TransferError(err)
44 +
		return 0, err
46 45
	}
47 46
48 47
	f.Connection.UpdateLastActivity()
@@ -61,7 +60,9 @@
Loading
61 60
// Write writes the contents to upload
62 61
func (f *httpdFile) Write(p []byte) (n int, err error) {
63 62
	if atomic.LoadInt32(&f.AbortTransfer) == 1 {
64 -
		return 0, errTransferAborted
63 +
		err := f.GetAbortError()
64 +
		f.TransferError(err)
65 +
		return 0, err
65 66
	}
66 67
67 68
	f.Connection.UpdateLastActivity()

@@ -149,8 +149,8 @@
Loading
149 149
		}
150 150
	}
151 151
152 -
	baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, fsPath, virtualPath, common.TransferDownload,
153 -
		0, 0, 0, false, fs)
152 +
	baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, fsPath, virtualPath,
153 +
		common.TransferDownload, 0, 0, 0, 0, false, fs)
154 154
155 155
	return newWebDavFile(baseTransfer, nil, r), nil
156 156
}
@@ -214,7 +214,7 @@
Loading
214 214
	maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported())
215 215
216 216
	baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
217 -
		common.TransferUpload, 0, 0, maxWriteSize, true, fs)
217 +
		common.TransferUpload, 0, 0, maxWriteSize, 0, true, fs)
218 218
219 219
	return newWebDavFile(baseTransfer, w, nil), nil
220 220
}
@@ -252,6 +252,7 @@
Loading
252 252
		return nil, c.GetFsError(fs, err)
253 253
	}
254 254
	initialSize := int64(0)
255 +
	truncatedSize := int64(0) // bytes truncated and not included in quota
255 256
	if vfs.IsLocalOrSFTPFs(fs) {
256 257
		vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
257 258
		if err == nil {
@@ -264,12 +265,13 @@
Loading
264 265
		}
265 266
	} else {
266 267
		initialSize = fileSize
268 +
		truncatedSize = fileSize
267 269
	}
268 270
269 271
	vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID())
270 272
271 273
	baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
272 -
		common.TransferUpload, 0, initialSize, maxWriteSize, false, fs)
274 +
		common.TransferUpload, 0, initialSize, maxWriteSize, truncatedSize, false, fs)
273 275
274 276
	return newWebDavFile(baseTransfer, w, nil), nil
275 277
}

@@ -85,7 +85,7 @@
Loading
85 85
	}
86 86
87 87
	baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, p, p, request.Filepath, common.TransferDownload,
88 -
		0, 0, 0, false, fs)
88 +
		0, 0, 0, 0, false, fs)
89 89
	t := newTransfer(baseTransfer, nil, r, nil)
90 90
91 91
	return t, nil
@@ -364,7 +364,7 @@
Loading
364 364
	maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported())
365 365
366 366
	baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
367 -
		common.TransferUpload, 0, 0, maxWriteSize, true, fs)
367 +
		common.TransferUpload, 0, 0, maxWriteSize, 0, true, fs)
368 368
	t := newTransfer(baseTransfer, w, nil, errForRead)
369 369
370 370
	return t, nil
@@ -415,6 +415,7 @@
Loading
415 415
	}
416 416
417 417
	initialSize := int64(0)
418 +
	truncatedSize := int64(0) // bytes truncated and not included in quota
418 419
	if isResume {
419 420
		c.Log(logger.LevelDebug, "resuming upload requested, file path %#v initial size: %v has append flag %v",
420 421
			filePath, fileSize, pflags.Append)
@@ -436,13 +437,14 @@
Loading
436 437
			}
437 438
		} else {
438 439
			initialSize = fileSize
440 +
			truncatedSize = fileSize
439 441
		}
440 442
	}
441 443
442 444
	vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID())
443 445
444 446
	baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
445 -
		common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, false, fs)
447 +
		common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, truncatedSize, false, fs)
446 448
	t := newTransfer(baseTransfer, w, nil, errForRead)
447 449
448 450
	return t, nil

@@ -356,7 +356,7 @@
Loading
356 356
	go func() {
357 357
		defer stdin.Close()
358 358
		baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, command.fsPath, sshDestPath,
359 -
			common.TransferUpload, 0, 0, remainingQuotaSize, false, command.fs)
359 +
			common.TransferUpload, 0, 0, remainingQuotaSize, 0, false, command.fs)
360 360
		transfer := newTransfer(baseTransfer, nil, nil, nil)
361 361
362 362
		w, e := transfer.copyFromReaderToWriter(stdin, c.connection.channel)
@@ -369,7 +369,7 @@
Loading
369 369
370 370
	go func() {
371 371
		baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, command.fsPath, sshDestPath,
372 -
			common.TransferDownload, 0, 0, 0, false, command.fs)
372 +
			common.TransferDownload, 0, 0, 0, 0, false, command.fs)
373 373
		transfer := newTransfer(baseTransfer, nil, nil, nil)
374 374
375 375
		w, e := transfer.copyFromReaderToWriter(c.connection.channel, stdout)
@@ -383,7 +383,7 @@
Loading
383 383
384 384
	go func() {
385 385
		baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, command.fsPath, sshDestPath,
386 -
			common.TransferDownload, 0, 0, 0, false, command.fs)
386 +
			common.TransferDownload, 0, 0, 0, 0, false, command.fs)
387 387
		transfer := newTransfer(baseTransfer, nil, nil, nil)
388 388
389 389
		w, e := transfer.copyFromReaderToWriter(c.connection.channel.(ssh.Channel).Stderr(), stderr)

@@ -238,6 +238,7 @@
Loading
238 238
	}
239 239
240 240
	initialSize := int64(0)
241 +
	truncatedSize := int64(0) // bytes truncated and not included in quota
241 242
	if !isNewFile {
242 243
		if vfs.IsLocalOrSFTPFs(fs) {
243 244
			vfolder, err := c.connection.User.GetVirtualFolderForPath(path.Dir(requestPath))
@@ -251,6 +252,7 @@
Loading
251 252
			}
252 253
		} else {
253 254
			initialSize = fileSize
255 +
			truncatedSize = initialSize
254 256
		}
255 257
		if maxWriteSize > 0 {
256 258
			maxWriteSize += fileSize
@@ -260,7 +262,7 @@
Loading
260 262
	vfs.SetPathPermissions(fs, filePath, c.connection.User.GetUID(), c.connection.User.GetGID())
261 263
262 264
	baseTransfer := common.NewBaseTransfer(file, c.connection.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
263 -
		common.TransferUpload, 0, initialSize, maxWriteSize, isNewFile, fs)
265 +
		common.TransferUpload, 0, initialSize, maxWriteSize, truncatedSize, isNewFile, fs)
264 266
	t := newTransfer(baseTransfer, w, nil, nil)
265 267
266 268
	return c.getUploadFileData(sizeToRead, t)
@@ -529,7 +531,7 @@
Loading
529 531
	}
530 532
531 533
	baseTransfer := common.NewBaseTransfer(file, c.connection.BaseConnection, cancelFn, p, p, filePath,
532 -
		common.TransferDownload, 0, 0, 0, false, fs)
534 +
		common.TransferDownload, 0, 0, 0, 0, false, fs)
533 535
	t := newTransfer(baseTransfer, nil, r, nil)
534 536
535 537
	err = c.sendDownloadFileData(fs, p, stat, t)

@@ -0,0 +1,167 @@
Loading
1 +
package common
2 +
3 +
import (
4 +
	"errors"
5 +
	"sync"
6 +
	"time"
7 +
8 +
	"github.com/drakkan/sftpgo/v2/dataprovider"
9 +
	"github.com/drakkan/sftpgo/v2/logger"
10 +
	"github.com/drakkan/sftpgo/v2/util"
11 +
)
12 +
13 +
type overquotaTransfer struct {
14 +
	ConnID     string
15 +
	TransferID int64
16 +
}
17 +
18 +
// TransfersChecker defines the interface that transfer checkers must implement.
19 +
// A transfer checker ensure that multiple concurrent transfers does not exceeded
20 +
// the remaining user quota
21 +
type TransfersChecker interface {
22 +
	AddTransfer(transfer dataprovider.ActiveTransfer)
23 +
	RemoveTransfer(ID int64, connectionID string)
24 +
	UpdateTransferCurrentSize(ulSize int64, dlSize int64, ID int64, connectionID string)
25 +
	GetOverquotaTransfers() []overquotaTransfer
26 +
}
27 +
28 +
func getTransfersChecker() TransfersChecker {
29 +
	return &transfersCheckerMem{}
30 +
}
31 +
32 +
type transfersCheckerMem struct {
33 +
	sync.RWMutex
34 +
	transfers []dataprovider.ActiveTransfer
35 +
}
36 +
37 +
func (t *transfersCheckerMem) AddTransfer(transfer dataprovider.ActiveTransfer) {
38 +
	t.Lock()
39 +
	defer t.Unlock()
40 +
41 +
	t.transfers = append(t.transfers, transfer)
42 +
}
43 +
44 +
func (t *transfersCheckerMem) RemoveTransfer(ID int64, connectionID string) {
45 +
	t.Lock()
46 +
	defer t.Unlock()
47 +
48 +
	for idx, transfer := range t.transfers {
49 +
		if transfer.ID == ID && transfer.ConnID == connectionID {
50 +
			lastIdx := len(t.transfers) - 1
51 +
			t.transfers[idx] = t.transfers[lastIdx]
52 +
			t.transfers = t.transfers[:lastIdx]
53 +
			return
54 +
		}
55 +
	}
56 +
}
57 +
58 +
func (t *transfersCheckerMem) UpdateTransferCurrentSize(ulSize int64, dlSize int64, ID int64, connectionID string) {
59 +
	t.Lock()
60 +
	defer t.Unlock()
61 +
62 +
	for idx := range t.transfers {
63 +
		if t.transfers[idx].ID == ID && t.transfers[idx].ConnID == connectionID {
64 +
			t.transfers[idx].CurrentDLSize = dlSize
65 +
			t.transfers[idx].CurrentULSize = ulSize
66 +
			t.transfers[idx].UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
67 +
			return
68 +
		}
69 +
	}
70 +
}
71 +
72 +
func (t *transfersCheckerMem) getRemainingDiskQuota(user dataprovider.User, folderName string) (int64, error) {
73 +
	var result int64
74 +
75 +
	if folderName != "" {
76 +
		for _, folder := range user.VirtualFolders {
77 +
			if folder.Name == folderName {
78 +
				if folder.QuotaSize > 0 {
79 +
					return folder.QuotaSize - folder.UsedQuotaSize, nil
80 +
				}
81 +
			}
82 +
		}
83 +
	} else {
84 +
		if user.QuotaSize > 0 {
85 +
			return user.QuotaSize - user.UsedQuotaSize, nil
86 +
		}
87 +
	}
88 +
89 +
	return result, errors.New("no quota limit defined")
90 +
}
91 +
92 +
func (t *transfersCheckerMem) aggregateTransfers() (map[string]bool, map[string][]dataprovider.ActiveTransfer) {
93 +
	t.RLock()
94 +
	defer t.RUnlock()
95 +
96 +
	usersToFetch := make(map[string]bool)
97 +
	aggregations := make(map[string][]dataprovider.ActiveTransfer)
98 +
	for _, transfer := range t.transfers {
99 +
		key := transfer.GetKey()
100 +
		aggregations[key] = append(aggregations[key], transfer)
101 +
		if len(aggregations[key]) > 1 {
102 +
			if transfer.FolderName != "" {
103 +
				usersToFetch[transfer.Username] = true
104 +
			} else {
105 +
				if _, ok := usersToFetch[transfer.Username]; !ok {
106 +
					usersToFetch[transfer.Username] = false
107 +
				}
108 +
			}
109 +
		}
110 +
	}
111 +
112 +
	return usersToFetch, aggregations
113 +
}
114 +
115 +
func (t *transfersCheckerMem) GetOverquotaTransfers() []overquotaTransfer {
116 +
	usersToFetch, aggregations := t.aggregateTransfers()
117 +
118 +
	if len(usersToFetch) == 0 {
119 +
		return nil
120 +
	}
121 +
122 +
	users, err := dataprovider.GetUsersForQuotaCheck(usersToFetch)
123 +
	if err != nil {
124 +
		logger.Warn(logSender, "", "unable to check transfers, error getting users quota: %v", err)
125 +
		return nil
126 +
	}
127 +
128 +
	usersMap := make(map[string]dataprovider.User)
129 +
130 +
	for _, user := range users {
131 +
		usersMap[user.Username] = user
132 +
	}
133 +
134 +
	var overquotaTransfers []overquotaTransfer
135 +
136 +
	for _, transfers := range aggregations {
137 +
		if len(transfers) > 1 {
138 +
			username := transfers[0].Username
139 +
			folderName := transfers[0].FolderName
140 +
			// transfer type is always upload for now
141 +
			remaningDiskQuota, err := t.getRemainingDiskQuota(usersMap[username], folderName)
142 +
			if err != nil {
143 +
				continue
144 +
			}
145 +
			var usedDiskQuota int64
146 +
			for _, tr := range transfers {
147 +
				// We optimistically assume that a cloud transfer that replaces an existing
148 +
				// file will be successful
149 +
				usedDiskQuota += tr.CurrentULSize - tr.TruncatedSize
150 +
			}
151 +
			logger.Debug(logSender, "", "username %#v, folder %#v, concurrent transfers: %v, remaining disk quota: %v, disk quota used in ongoing transfers: %v",
152 +
				username, folderName, len(transfers), remaningDiskQuota, usedDiskQuota)
153 +
			if usedDiskQuota > remaningDiskQuota {
154 +
				for _, tr := range transfers {
155 +
					if tr.CurrentULSize > tr.TruncatedSize {
156 +
						overquotaTransfers = append(overquotaTransfers, overquotaTransfer{
157 +
							ConnID:     tr.ConnID,
158 +
							TransferID: tr.ID,
159 +
						})
160 +
					}
161 +
				}
162 +
			}
163 +
		}
164 +
	}
165 +
166 +
	return overquotaTransfers
167 +
}

@@ -27,7 +27,7 @@
Loading
27 27
	lastActivity int64
28 28
	// unique ID for a transfer.
29 29
	// This field is accessed atomically so we put it at the beginning of the struct to achieve 64 bit alignment
30 -
	transferID uint64
30 +
	transferID int64
31 31
	// Unique identifier for the connection
32 32
	ID string
33 33
	// user associated with this connection if any
@@ -66,8 +66,8 @@
Loading
66 66
}
67 67
68 68
// GetTransferID returns an unique transfer ID for this connection
69 -
func (c *BaseConnection) GetTransferID() uint64 {
70 -
	return atomic.AddUint64(&c.transferID, 1)
69 +
func (c *BaseConnection) GetTransferID() int64 {
70 +
	return atomic.AddInt64(&c.transferID, 1)
71 71
}
72 72
73 73
// GetID returns the connection ID
@@ -125,6 +125,27 @@
Loading
125 125
126 126
	c.activeTransfers = append(c.activeTransfers, t)
127 127
	c.Log(logger.LevelDebug, "transfer added, id: %v, active transfers: %v", t.GetID(), len(c.activeTransfers))
128 +
	if t.GetMaxAllowedSize() > 0 {
129 +
		folderName := ""
130 +
		if t.GetType() == TransferUpload {
131 +
			vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(t.GetVirtualPath()))
132 +
			if err == nil {
133 +
				if !vfolder.IsIncludedInUserQuota() {
134 +
					folderName = vfolder.Name
135 +
				}
136 +
			}
137 +
		}
138 +
		go transfersChecker.AddTransfer(dataprovider.ActiveTransfer{
139 +
			ID:            t.GetID(),
140 +
			Type:          t.GetType(),
141 +
			ConnID:        c.ID,
142 +
			Username:      c.GetUsername(),
143 +
			FolderName:    folderName,
144 +
			TruncatedSize: t.GetTruncatedSize(),
145 +
			CreatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
146 +
			UpdatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
147 +
		})
148 +
	}
128 149
}
129 150
130 151
// RemoveTransfer removes the specified transfer from the active ones
@@ -132,6 +153,10 @@
Loading
132 153
	c.Lock()
133 154
	defer c.Unlock()
134 155
156 +
	if t.GetMaxAllowedSize() > 0 {
157 +
		go transfersChecker.RemoveTransfer(t.GetID(), c.ID)
158 +
	}
159 +
135 160
	for idx, transfer := range c.activeTransfers {
136 161
		if transfer.GetID() == t.GetID() {
137 162
			lastIdx := len(c.activeTransfers) - 1
@@ -145,6 +170,20 @@
Loading
145 170
	c.Log(logger.LevelWarn, "transfer to remove with id %v not found!", t.GetID())
146 171
}
147 172
173 +
// SignalTransferClose makes the transfer fail on the next read/write with the
174 +
// specified error
175 +
func (c *BaseConnection) SignalTransferClose(transferID int64, err error) {
176 +
	c.RLock()
177 +
	defer c.RUnlock()
178 +
179 +
	for _, t := range c.activeTransfers {
180 +
		if t.GetID() == transferID {
181 +
			c.Log(logger.LevelInfo, "signal transfer close for transfer id %v", transferID)
182 +
			t.SignalClose(err)
183 +
		}
184 +
	}
185 +
}
186 +
148 187
// GetTransfers returns the active transfers
149 188
func (c *BaseConnection) GetTransfers() []ConnectionTransfer {
150 189
	c.RLock()
@@ -160,11 +199,14 @@
Loading
160 199
			operationType = operationUpload
161 200
		}
162 201
		transfers = append(transfers, ConnectionTransfer{
163 -
			ID:            t.GetID(),
164 -
			OperationType: operationType,
165 -
			StartTime:     util.GetTimeAsMsSinceEpoch(t.GetStartTime()),
166 -
			Size:          t.GetSize(),
167 -
			VirtualPath:   t.GetVirtualPath(),
202 +
			ID:             t.GetID(),
203 +
			OperationType:  operationType,
204 +
			StartTime:      util.GetTimeAsMsSinceEpoch(t.GetStartTime()),
205 +
			Size:           t.GetSize(),
206 +
			VirtualPath:    t.GetVirtualPath(),
207 +
			MaxAllowedSize: t.GetMaxAllowedSize(),
208 +
			ULSize:         t.GetUploadedSize(),
209 +
			DLSize:         t.GetDownloadedSize(),
168 210
		})
169 211
	}
170 212
@@ -181,7 +223,7 @@
Loading
181 223
	}
182 224
183 225
	for _, t := range c.activeTransfers {
184 -
		t.SignalClose()
226 +
		t.SignalClose(ErrTransferAborted)
185 227
	}
186 228
	return nil
187 229
}
@@ -1208,9 +1250,8 @@
Loading
1208 1250
	}
1209 1251
}
1210 1252
1211 -
// GetQuotaExceededError returns an appropriate storage limit exceeded error for the connection protocol
1212 -
func (c *BaseConnection) GetQuotaExceededError() error {
1213 -
	switch c.protocol {
1253 +
func getQuotaExceededError(protocol string) error {
1254 +
	switch protocol {
1214 1255
	case ProtocolSFTP:
1215 1256
		return fmt.Errorf("%w: %v", sftp.ErrSSHFxFailure, ErrQuotaExceeded.Error())
1216 1257
	case ProtocolFTP:
@@ -1220,6 +1261,11 @@
Loading
1220 1261
	}
1221 1262
}
1222 1263
1264 +
// GetQuotaExceededError returns an appropriate storage limit exceeded error for the connection protocol
1265 +
func (c *BaseConnection) GetQuotaExceededError() error {
1266 +
	return getQuotaExceededError(c.protocol)
1267 +
}
1268 +
1223 1269
// IsQuotaExceededError returns true if the given error is a quota exceeded error
1224 1270
func (c *BaseConnection) IsQuotaExceededError(err error) bool {
1225 1271
	switch c.protocol {

@@ -20,7 +20,7 @@
Loading
20 20
21 21
// BaseTransfer contains protocols common transfer details for an upload or a download.
22 22
type BaseTransfer struct { //nolint:maligned
23 -
	ID              uint64
23 +
	ID              int64
24 24
	BytesSent       int64
25 25
	BytesReceived   int64
26 26
	Fs              vfs.Fs
@@ -35,18 +35,21 @@
Loading
35 35
	MaxWriteSize    int64
36 36
	MinWriteOffset  int64
37 37
	InitialSize     int64
38 +
	truncatedSize   int64
38 39
	isNewFile       bool
39 40
	transferType    int
40 41
	AbortTransfer   int32
41 42
	aTime           time.Time
42 43
	mTime           time.Time
43 44
	sync.Mutex
45 +
	errAbort    error
44 46
	ErrTransfer error
45 47
}
46 48
47 49
// NewBaseTransfer returns a new BaseTransfer and adds it to the given connection
48 50
func NewBaseTransfer(file vfs.File, conn *BaseConnection, cancelFn func(), fsPath, effectiveFsPath, requestPath string,
49 -
	transferType int, minWriteOffset, initialSize, maxWriteSize int64, isNewFile bool, fs vfs.Fs) *BaseTransfer {
51 +
	transferType int, minWriteOffset, initialSize, maxWriteSize, truncatedSize int64, isNewFile bool, fs vfs.Fs,
52 +
) *BaseTransfer {
50 53
	t := &BaseTransfer{
51 54
		ID:              conn.GetTransferID(),
52 55
		File:            file,
@@ -64,6 +67,7 @@
Loading
64 67
		BytesReceived:   0,
65 68
		MaxWriteSize:    maxWriteSize,
66 69
		AbortTransfer:   0,
70 +
		truncatedSize:   truncatedSize,
67 71
		Fs:              fs,
68 72
	}
69 73
@@ -77,7 +81,7 @@
Loading
77 81
}
78 82
79 83
// GetID returns the transfer ID
80 -
func (t *BaseTransfer) GetID() uint64 {
84 +
func (t *BaseTransfer) GetID() int64 {
81 85
	return t.ID
82 86
}
83 87
@@ -94,19 +98,53 @@
Loading
94 98
	return atomic.LoadInt64(&t.BytesReceived)
95 99
}
96 100
101 +
// GetDownloadedSize returns the transferred size
102 +
func (t *BaseTransfer) GetDownloadedSize() int64 {
103 +
	return atomic.LoadInt64(&t.BytesSent)
104 +
}
105 +
106 +
// GetUploadedSize returns the transferred size
107 +
func (t *BaseTransfer) GetUploadedSize() int64 {
108 +
	return atomic.LoadInt64(&t.BytesReceived)
109 +
}
110 +
97 111
// GetStartTime returns the start time
98 112
func (t *BaseTransfer) GetStartTime() time.Time {
99 113
	return t.start
100 114
}
101 115
102 -
// SignalClose signals that the transfer should be closed.
103 -
// For same protocols, for example WebDAV, we have no
104 -
// access to the network connection, so we use this method
105 -
// to make the next read or write to fail
106 -
func (t *BaseTransfer) SignalClose() {
116 +
// GetAbortError returns the error to send to the client if the transfer was aborted
117 +
func (t *BaseTransfer) GetAbortError() error {
118 +
	t.Lock()
119 +
	defer t.Unlock()
120 +
121 +
	if t.errAbort != nil {
122 +
		return t.errAbort
123 +
	}
124 +
	return getQuotaExceededError(t.Connection.protocol)
125 +
}
126 +
127 +
// SignalClose signals that the transfer should be closed after the next read/write.
128 +
// The optional error argument allow to send a specific error, otherwise a generic
129 +
// transfer aborted error is sent
130 +
func (t *BaseTransfer) SignalClose(err error) {
131 +
	t.Lock()
132 +
	t.errAbort = err
133 +
	t.Unlock()
107 134
	atomic.StoreInt32(&(t.AbortTransfer), 1)
108 135
}
109 136
137 +
// GetTruncatedSize returns the truncated sized if this is an upload overwriting
138 +
// an existing file
139 +
func (t *BaseTransfer) GetTruncatedSize() int64 {
140 +
	return t.truncatedSize
141 +
}
142 +
143 +
// GetMaxAllowedSize returns the max allowed size
144 +
func (t *BaseTransfer) GetMaxAllowedSize() int64 {
145 +
	return t.MaxWriteSize
146 +
}
147 +
110 148
// GetVirtualPath returns the transfer virtual path
111 149
func (t *BaseTransfer) GetVirtualPath() string {
112 150
	return t.requestPath

@@ -53,9 +53,10 @@
Loading
53 53
	operationMkdir     = "mkdir"
54 54
	operationRmdir     = "rmdir"
55 55
	// SSH command action name
56 -
	OperationSSHCmd          = "ssh_cmd"
57 -
	chtimesFormat            = "2006-01-02T15:04:05" // YYYY-MM-DDTHH:MM:SS
58 -
	idleTimeoutCheckInterval = 3 * time.Minute
56 +
	OperationSSHCmd              = "ssh_cmd"
57 +
	chtimesFormat                = "2006-01-02T15:04:05" // YYYY-MM-DDTHH:MM:SS
58 +
	idleTimeoutCheckInterval     = 3 * time.Minute
59 +
	periodicTimeoutCheckInterval = 1 * time.Minute
59 60
)
60 61
61 62
// Stat flags
@@ -110,6 +111,7 @@
Loading
110 111
	ErrCrtRevoked           = errors.New("your certificate has been revoked")
111 112
	ErrNoCredentials        = errors.New("no credential provided")
112 113
	ErrInternalFailure      = errors.New("internal failure")
114 +
	ErrTransferAborted      = errors.New("transfer aborted")
113 115
	errNoTransfer           = errors.New("requested transfer not found")
114 116
	errTransferMismatch     = errors.New("transfer mismatch")
115 117
)
@@ -120,10 +122,11 @@
Loading
120 122
	// Connections is the list of active connections
121 123
	Connections ActiveConnections
122 124
	// QuotaScans is the list of active quota scans
123 -
	QuotaScans            ActiveScans
124 -
	idleTimeoutTicker     *time.Ticker
125 -
	idleTimeoutTickerDone chan bool
126 -
	supportedProtocols    = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP, ProtocolWebDAV,
125 +
	QuotaScans                ActiveScans
126 +
	transfersChecker          TransfersChecker
127 +
	periodicTimeoutTicker     *time.Ticker
128 +
	periodicTimeoutTickerDone chan bool
129 +
	supportedProtocols        = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP, ProtocolWebDAV,
127 130
		ProtocolHTTP, ProtocolHTTPShare}
128 131
	disconnHookProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP}
129 132
	// the map key is the protocol, for each protocol we can have multiple rate limiters
@@ -135,9 +138,7 @@
Loading
135 138
	Config = c
136 139
	Config.idleLoginTimeout = 2 * time.Minute
137 140
	Config.idleTimeoutAsDuration = time.Duration(Config.IdleTimeout) * time.Minute
138 -
	if Config.IdleTimeout > 0 {
139 -
		startIdleTimeoutTicker(idleTimeoutCheckInterval)
140 -
	}
141 +
	startPeriodicTimeoutTicker(periodicTimeoutCheckInterval)
141 142
	Config.defender = nil
142 143
	rateLimiters = make(map[string][]*rateLimiter)
143 144
	for _, rlCfg := range c.RateLimitersConfig {
@@ -176,6 +177,7 @@
Loading
176 177
	}
177 178
	vfs.SetTempPath(c.TempPath)
178 179
	dataprovider.SetTempPath(c.TempPath)
180 +
	transfersChecker = getTransfersChecker()
179 181
	return nil
180 182
}
181 183
@@ -267,41 +269,52 @@
Loading
267 269
}
268 270
269 271
// the ticker cannot be started/stopped from multiple goroutines
270 -
func startIdleTimeoutTicker(duration time.Duration) {
271 -
	stopIdleTimeoutTicker()
272 -
	idleTimeoutTicker = time.NewTicker(duration)
273 -
	idleTimeoutTickerDone = make(chan bool)
272 +
func startPeriodicTimeoutTicker(duration time.Duration) {
273 +
	stopPeriodicTimeoutTicker()
274 +
	periodicTimeoutTicker = time.NewTicker(duration)
275 +
	periodicTimeoutTickerDone = make(chan bool)
274 276
	go func() {
277 +
		counter := int64(0)
278 +
		ratio := idleTimeoutCheckInterval / periodicTimeoutCheckInterval
275 279
		for {
276 280
			select {
277 -
			case <-idleTimeoutTickerDone:
281 +
			case <-periodicTimeoutTickerDone:
278 282
				return
279 -
			case <-idleTimeoutTicker.C:
280 -
				Connections.checkIdles()
283 +
			case <-periodicTimeoutTicker.C:
284 +
				counter++
285 +
				if Config.IdleTimeout > 0 && counter >= int64(ratio) {
286 +
					counter = 0
287 +
					Connections.checkIdles()
288 +
				}
289 +
				go Connections.checkTransfers()
281 290
			}
282 291
		}
283 292
	}()
284 293
}
285 294
286 -
func stopIdleTimeoutTicker() {
287 -
	if idleTimeoutTicker != nil {
288 -
		idleTimeoutTicker.Stop()
289 -
		idleTimeoutTickerDone <- true
290 -
		idleTimeoutTicker = nil
295 +
func stopPeriodicTimeoutTicker() {
296 +
	if periodicTimeoutTicker != nil {
297 +
		periodicTimeoutTicker.Stop()
298 +
		periodicTimeoutTickerDone <- true
299 +
		periodicTimeoutTicker = nil
291 300
	}
292 301
}
293 302
294 303
// ActiveTransfer defines the interface for the current active transfers
295 304
type ActiveTransfer interface {
296 -
	GetID() uint64
305 +
	GetID() int64
297 306
	GetType() int
298 307
	GetSize() int64
308 +
	GetDownloadedSize() int64
309 +
	GetUploadedSize() int64
299 310
	GetVirtualPath() string
300 311
	GetStartTime() time.Time
301 -
	SignalClose()
312 +
	SignalClose(err error)
302 313
	Truncate(fsPath string, size int64) (int64, error)
303 314
	GetRealFsPath(fsPath string) string
304 315
	SetTimes(fsPath string, atime time.Time, mtime time.Time) bool
316 +
	GetTruncatedSize() int64
317 +
	GetMaxAllowedSize() int64
305 318
}
306 319
307 320
// ActiveConnection defines the interface for the current active connections
@@ -319,6 +332,7 @@
Loading
319 332
	AddTransfer(t ActiveTransfer)
320 333
	RemoveTransfer(t ActiveTransfer)
321 334
	GetTransfers() []ConnectionTransfer
335 +
	SignalTransferClose(transferID int64, err error)
322 336
	CloseFS() error
323 337
}
324 338
@@ -335,11 +349,14 @@
Loading
335 349
336 350
// ConnectionTransfer defines the trasfer details to expose
337 351
type ConnectionTransfer struct {
338 -
	ID            uint64 `json:"-"`
339 -
	OperationType string `json:"operation_type"`
340 -
	StartTime     int64  `json:"start_time"`
341 -
	Size          int64  `json:"size"`
342 -
	VirtualPath   string `json:"path"`
352 +
	ID             int64  `json:"-"`
353 +
	OperationType  string `json:"operation_type"`
354 +
	StartTime      int64  `json:"start_time"`
355 +
	Size           int64  `json:"size"`
356 +
	VirtualPath    string `json:"path"`
357 +
	MaxAllowedSize int64  `json:"-"`
358 +
	ULSize         int64  `json:"-"`
359 +
	DLSize         int64  `json:"-"`
343 360
}
344 361
345 362
func (t *ConnectionTransfer) getConnectionTransferAsString() string {
@@ -653,7 +670,8 @@
Loading
653 670
type ActiveConnections struct {
654 671
	// clients contains both authenticated and estabilished connections and the ones waiting
655 672
	// for authentication
656 -
	clients clientsMap
673 +
	clients              clientsMap
674 +
	transfersCheckStatus int32
657 675
	sync.RWMutex
658 676
	connections    []ActiveConnection
659 677
	sshConnections []*SSHConnection
@@ -825,6 +843,59 @@
Loading
825 843
	conns.RUnlock()
826 844
}
827 845
846 +
func (conns *ActiveConnections) checkTransfers() {
847 +
	if atomic.LoadInt32(&conns.transfersCheckStatus) == 1 {
848 +
		logger.Warn(logSender, "", "the previous transfer check is still running, skipping execution")
849 +
		return
850 +
	}
851 +
	atomic.StoreInt32(&conns.transfersCheckStatus, 1)
852 +
	defer atomic.StoreInt32(&conns.transfersCheckStatus, 0)
853 +
854 +
	var wg sync.WaitGroup
855 +
856 +
	logger.Debug(logSender, "", "start concurrent transfers check")
857 +
	conns.RLock()
858 +
859 +
	// update the current size for transfers to monitors
860 +
	for _, c := range conns.connections {
861 +
		for _, t := range c.GetTransfers() {
862 +
			if t.MaxAllowedSize > 0 {
863 +
				wg.Add(1)
864 +
865 +
				go func(transfer ConnectionTransfer, connID string) {
866 +
					defer wg.Done()
867 +
					transfersChecker.UpdateTransferCurrentSize(transfer.ULSize, transfer.DLSize, transfer.ID, connID)
868 +
				}(t, c.GetID())
869 +
			}
870 +
		}
871 +
	}
872 +
873 +
	conns.RUnlock()
874 +
	logger.Debug(logSender, "", "waiting for the update of the transfers current size")
875 +
	wg.Wait()
876 +
877 +
	logger.Debug(logSender, "", "getting overquota transfers")
878 +
	overquotaTransfers := transfersChecker.GetOverquotaTransfers()
879 +
	logger.Debug(logSender, "", "number of overquota transfers: %v", len(overquotaTransfers))
880 +
	if len(overquotaTransfers) == 0 {
881 +
		return
882 +
	}
883 +
884 +
	conns.RLock()
885 +
	defer conns.RUnlock()
886 +
887 +
	for _, c := range conns.connections {
888 +
		for _, overquotaTransfer := range overquotaTransfers {
889 +
			if c.GetID() == overquotaTransfer.ConnID {
890 +
				logger.Info(logSender, c.GetID(), "user %#v is overquota, try to close transfer id %v ",
891 +
					c.GetUsername(), overquotaTransfer.TransferID)
892 +
				c.SignalTransferClose(overquotaTransfer.TransferID, getQuotaExceededError(c.GetProtocol()))
893 +
			}
894 +
		}
895 +
	}
896 +
	logger.Debug(logSender, "", "transfers check completed")
897 +
}
898 +
828 899
// AddClientConnection stores a new client connection
829 900
func (conns *ActiveConnections) AddClientConnection(ipAddr string) {
830 901
	conns.clients.add(ipAddr)

@@ -335,8 +335,8 @@
Loading
335 335
		return nil, c.GetFsError(fs, err)
336 336
	}
337 337
338 -
	baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, fsPath, ftpPath, common.TransferDownload,
339 -
		0, 0, 0, false, fs)
338 +
	baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, fsPath, ftpPath,
339 +
		common.TransferDownload, 0, 0, 0, 0, false, fs)
340 340
	baseTransfer.SetFtpMode(c.getFTPMode())
341 341
	t := newTransfer(baseTransfer, nil, r, offset)
342 342
@@ -402,7 +402,7 @@
Loading
402 402
	maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported())
403 403
404 404
	baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
405 -
		common.TransferUpload, 0, 0, maxWriteSize, true, fs)
405 +
		common.TransferUpload, 0, 0, maxWriteSize, 0, true, fs)
406 406
	baseTransfer.SetFtpMode(c.getFTPMode())
407 407
	t := newTransfer(baseTransfer, w, nil, 0)
408 408
@@ -452,6 +452,7 @@
Loading
452 452
	}
453 453
454 454
	initialSize := int64(0)
455 +
	truncatedSize := int64(0) // bytes truncated and not included in quota
455 456
	if isResume {
456 457
		c.Log(logger.LevelDebug, "resuming upload requested, file path: %#v initial size: %v", filePath, fileSize)
457 458
		minWriteOffset = fileSize
@@ -473,13 +474,14 @@
Loading
473 474
			}
474 475
		} else {
475 476
			initialSize = fileSize
477 +
			truncatedSize = fileSize
476 478
		}
477 479
	}
478 480
479 481
	vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID())
480 482
481 483
	baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
482 -
		common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, false, fs)
484 +
		common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, truncatedSize, false, fs)
483 485
	baseTransfer.SetFtpMode(c.getFTPMode())
484 486
	t := newTransfer(baseTransfer, w, nil, 0)
485 487
Files Coverage
common 99.91%
config 100.00%
ftpd 100.00%
httpd 99.95%
mfa 100.00%
sftpd 98.52%
telemetry 100.00%
webdavd 100.00%
Project Totals (63 files) 99.72%
Sunburst
The inner-most circle is the entire project, moving away from the center are folders then, finally, a single file. The size and color of each slice is representing the number of statements and the coverage, respectively.
Icicle
The top section represents the entire project. Proceeding with folders and finally individual files. The size and color of each slice is representing the number of statements and the coverage, respectively.
Grid
Each block represents a single file in the project. The size and color of each block is represented by the number of statements and the coverage, respectively.
Loading