Files

153 lines
2.8 KiB
Go

package osimage
import (
"context"
"fmt"
"io"
"os"
"time"
)
// Tuning defaults for streaming an image onto a target.
//
// NOTE(review): the *BPS values are presumably bytes per second and, together
// with the burst/sample/busy/await values, are consumed by the adaptive write
// controller defined elsewhere in this package — confirm against that code.
const (
	// defaultWriteBufferSize is the copy buffer size (1 MiB) used when the
	// caller passes a non-positive bufferSize.
	defaultWriteBufferSize = 1 << 20

	defaultMinWriteBPS     = int64(2 << 20)
	defaultInitialWriteBPS = int64(4 << 20)
	defaultMaxWriteBPS     = int64(8 << 20)
	defaultBurstBytes      = int64(512 << 10)

	defaultSampleInterval = 250 * time.Millisecond

	// defaultSyncEveryBytes is passed as syncEvery to the copy loop; a value
	// of 0 means no periodic fsync (only the final sync is performed).
	defaultSyncEveryBytes = 0

	defaultBusyHighPct = 80.0
	defaultBusyLowPct  = 40.0

	defaultSlowAwait = 20 * time.Millisecond
	defaultFastAwait = 5 * time.Millisecond
)
// WriteStreamToTarget copies src into the existing file or device at
// targetPath, reporting progress under the "flash" stage, and returns the
// number of bytes written.
//
// The target is opened write-only without O_CREATE or O_TRUNC, so it must
// already exist and any bytes past the written range are left untouched.
// Writes are paced by an adaptive controller for targetPath; if that
// controller cannot be constructed, a no-op controller is used instead.
//
// A bufferSize <= 0 selects defaultWriteBufferSize. When expectedSize > 0 the
// copy must produce exactly that many bytes, otherwise an error is returned.
// On success the target is fsynced and then closed; a failure from either
// step is returned as an error.
func WriteStreamToTarget(ctx context.Context,
	src io.Reader, targetPath string,
	expectedSize int64, bufferSize int,
	progress ProgressFunc,
) (written int64, err error) {
	if targetPath == "" {
		return 0, fmt.Errorf("target path is required")
	}
	if src == nil {
		return 0, fmt.Errorf("source reader is required")
	}
	if bufferSize <= 0 {
		bufferSize = defaultWriteBufferSize
	}

	f, err := os.OpenFile(targetPath, os.O_WRONLY, 0)
	if err != nil {
		return 0, fmt.Errorf("open target: %w", err)
	}
	defer func() {
		// Close may report a deferred write failure; surface it unless an
		// earlier error is already being returned.
		if cerr := f.Close(); cerr != nil && err == nil {
			err = fmt.Errorf("close target: %w", cerr)
		}
	}()

	ctrl, ctrlErr := newAdaptiveWriteController(targetPath)
	if ctrlErr != nil {
		// Pacing is best-effort: fall back to an unthrottled controller
		// rather than failing the whole write.
		ctrl = newNoopAdaptiveWriteController()
	}

	written, err = copyWithProgressBuffer(
		ctx,
		f,
		src,
		expectedSize,
		"flash",
		progress,
		make([]byte, bufferSize),
		ctrl,
		defaultSyncEveryBytes,
	)
	if err != nil {
		return written, err
	}
	if expectedSize > 0 && written != expectedSize {
		return written, fmt.Errorf("written size mismatch: got %d want %d", written, expectedSize)
	}
	if err := f.Sync(); err != nil {
		return written, fmt.Errorf("sync target: %w", err)
	}
	return written, nil
}
// copyWithProgressBuffer copies src to dst using buf as scratch space and
// returns the number of bytes written to dst.
//
// Before each write, ctrl (when non-nil) may block via Wait; after each write
// it is told how many bytes were written via ObserveWrite. progress (when
// non-nil) is invoked after every write that wrote bytes, with the running
// byte count and total. When syncEvery > 0, dst is fsynced each time at least
// syncEvery bytes have been written since the previous sync, and ctrl is then
// notified via ObserveSync.
//
// ctx is checked once per loop iteration; a Read that is already blocked is
// not interrupted. io.EOF from src ends the copy successfully; other read,
// write and sync errors are returned wrapped with the stage name.
//
// If buf is empty, a buffer of defaultWriteBufferSize is allocated: a
// zero-length Read legitimately returns (0, nil), which would otherwise make
// this loop spin forever.
func copyWithProgressBuffer(
	ctx context.Context,
	dst *os.File,
	src io.Reader,
	total int64,
	stage string,
	progress ProgressFunc,
	buf []byte,
	ctrl *adaptiveWriteController,
	syncEvery int64,
) (int64, error) {
	if len(buf) == 0 {
		buf = make([]byte, defaultWriteBufferSize)
	}

	var written, sinceSync int64
	for {
		select {
		case <-ctx.Done():
			return written, ctx.Err()
		default:
		}

		nr, er := src.Read(buf)
		if nr > 0 {
			if ctrl != nil {
				if err := ctrl.Wait(ctx, nr); err != nil {
					return written, err
				}
			}
			nw, ew := dst.Write(buf[:nr])
			if nw > 0 {
				written += int64(nw)
				sinceSync += int64(nw)
				if ctrl != nil {
					ctrl.ObserveWrite(nw)
				}
				if progress != nil {
					progress(Progress{
						Stage:         stage,
						BytesComplete: written,
						BytesTotal:    total,
					})
				}
				if syncEvery > 0 && sinceSync >= syncEvery {
					if err := dst.Sync(); err != nil {
						return written, fmt.Errorf("periodic sync target: %w", err)
					}
					sinceSync = 0
					if ctrl != nil {
						ctrl.ObserveSync()
					}
				}
			}
			if ew != nil {
				return written, fmt.Errorf("write %s: %w", stage, ew)
			}
			if nw != nr {
				return written, fmt.Errorf("write %s: %w", stage, io.ErrShortWrite)
			}
		}

		// Data returned alongside an error has already been written above,
		// matching the io.Reader contract.
		if er != nil {
			if er == io.EOF {
				return written, nil
			}
			return written, fmt.Errorf("copy %s: %w", stage, er)
		}
	}
}