83 lines
1.6 KiB
Go
83 lines
1.6 KiB
Go
package osimage
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
)
|
|
|
|
func WriteStreamToTarget(ctx context.Context,
|
|
src io.Reader,
|
|
targetPath string,
|
|
expectedSize int64, bufferSize int,
|
|
progress ProgressFunc,
|
|
) (int64, error) {
|
|
if targetPath == "" {
|
|
return 0, fmt.Errorf("target path is required")
|
|
}
|
|
if bufferSize <= 0 {
|
|
bufferSize = 4 * 1024 * 1024
|
|
}
|
|
|
|
f, err := os.OpenFile(targetPath, os.O_WRONLY, 0)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("open target: %w", err)
|
|
}
|
|
defer f.Close()
|
|
|
|
written, err := copyWithProgressBuffer(ctx, f, src, expectedSize, "flash", progress, make([]byte, bufferSize))
|
|
if err != nil {
|
|
return written, err
|
|
}
|
|
|
|
if expectedSize > 0 && written != expectedSize {
|
|
return written, fmt.Errorf("written size mismatch: got %d want %d", written, expectedSize)
|
|
}
|
|
|
|
if err := f.Sync(); err != nil {
|
|
return written, fmt.Errorf("sync target: %w", err)
|
|
}
|
|
|
|
return written, nil
|
|
}
|
|
|
|
func copyWithProgressBuffer(ctx context.Context, dst io.Writer, src io.Reader, total int64, stage string, progress ProgressFunc, buf []byte) (int64, error) {
|
|
var written int64
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return written, ctx.Err()
|
|
default:
|
|
}
|
|
|
|
nr, er := src.Read(buf)
|
|
if nr > 0 {
|
|
nw, ew := dst.Write(buf[:nr])
|
|
if nw > 0 {
|
|
written += int64(nw)
|
|
if progress != nil {
|
|
progress(Progress{
|
|
Stage: stage,
|
|
BytesComplete: written,
|
|
BytesTotal: total,
|
|
})
|
|
}
|
|
}
|
|
if ew != nil {
|
|
return written, ew
|
|
}
|
|
if nw != nr {
|
|
return written, io.ErrShortWrite
|
|
}
|
|
}
|
|
if er != nil {
|
|
if er == io.EOF {
|
|
return written, nil
|
|
}
|
|
return written, fmt.Errorf("copy %s: %w", stage, er)
|
|
}
|
|
}
|
|
}
|