Writes and verify image

This commit is contained in:
2026-04-04 02:45:46 +08:00
parent 9cb593ffc0
commit 4f490ab37e
20 changed files with 596 additions and 65 deletions

View File

@@ -0,0 +1,42 @@
package osimage
import (
"context"
"fmt"
)
func ApplyImageStreamed(ctx context.Context, opts ApplyOptions) (*ApplyResult, error) {
if err := ValidateApplyOptions(opts); err != nil {
return nil, err
}
if err := CheckTargetSafe(opts.TargetPath, opts.ExpectedRawSize); err != nil {
return nil, fmt.Errorf("unsafe target %q: %w", opts.TargetPath, err)
}
src, closeFn, err := OpenDecompressedHTTPStream(ctx, opts.URL, opts.HTTPTimeout)
if err != nil {
return nil, fmt.Errorf("open source stream: %w", err)
}
defer closeFn()
written, err := WriteStreamToTarget(ctx, src, opts.TargetPath, opts.ExpectedRawSize, opts.BufferSize, opts.Progress)
if err != nil {
return nil, fmt.Errorf("write target: %w", err)
}
sum, err := VerifyTargetSHA256(ctx, opts.TargetPath, opts.ExpectedRawSize, opts.BufferSize, opts.Progress)
if err != nil {
return nil, fmt.Errorf("verify target: %w", err)
}
if err := VerifySHA256(sum, opts.ExpectedRawSHA256); err != nil {
return nil, fmt.Errorf("final disk checksum mismatch: %w", err)
}
return &ApplyResult{
BytesWritten: written,
VerifiedSHA256: sum,
VerificationOK: true,
}, nil
}

View File

@@ -0,0 +1,15 @@
package osimage
func PercentOf(done, total int64) int64 {
if total <= 0 {
return 0
}
p := (done * 100) / total
if p < 0 {
return 0
}
if p > 100 {
return 100
}
return p
}

View File

@@ -0,0 +1,121 @@
package osimage
import (
"k8s.io/klog/v2"
"time"
)
type progressState struct {
lastTime time.Time
lastPercent int64
lastBucket int64
}
type ProgressLogger struct {
minInterval time.Duration
bucketSize int64
states map[string]*progressState
}
func NewProgressLogger(minSeconds int, bucketSize int64) *ProgressLogger {
if minSeconds < 0 {
minSeconds = 0
}
if bucketSize <= 0 {
bucketSize = 10
}
return &ProgressLogger{
minInterval: time.Duration(minSeconds) * time.Second,
bucketSize: bucketSize,
states: make(map[string]*progressState),
}
}
func (l *ProgressLogger) state(stage string) *progressState {
s, ok := l.states[stage]
if ok {
return s
}
s = &progressState{
lastPercent: -1,
lastBucket: -1,
}
l.states[stage] = s
return s
}
func (l *ProgressLogger) Log(p Progress) {
if p.BytesTotal <= 0 {
return
}
percent := PercentOf(p.BytesComplete, p.BytesTotal)
now := time.Now()
bucket := percent / l.bucketSize
s := l.state(p.Stage)
// Always log first visible progress
if s.lastPercent == -1 {
s.lastPercent = percent
s.lastBucket = bucket
s.lastTime = now
klog.V(4).InfoS(p.Stage, "progress", percent)
return
}
// Always log completion once
if percent == 100 && s.lastPercent < 100 {
s.lastPercent = 100
s.lastBucket = 100 / l.bucketSize
s.lastTime = now
klog.V(4).InfoS(p.Stage, "progress", 100)
return
}
// Log if we crossed a new milestone bucket
if bucket > s.lastBucket {
s.lastPercent = percent
s.lastBucket = bucket
s.lastTime = now
klog.V(4).InfoS(p.Stage, "progress", percent)
return
}
// Otherwise allow a timed refresh if progress moved
if now.Sub(s.lastTime) >= l.minInterval && percent > s.lastPercent {
s.lastPercent = percent
s.lastTime = now
klog.V(4).InfoS(p.Stage, "progress", percent)
}
}
type TimeBasedUpdater struct {
interval time.Duration
lastRun time.Time
}
func NewTimeBasedUpdater(seconds int) *TimeBasedUpdater {
if seconds <= 0 {
seconds = 15
}
return &TimeBasedUpdater{
interval: time.Duration(seconds) * time.Second,
}
}
func (u *TimeBasedUpdater) Run(fn func() error) error {
now := time.Now()
if !u.lastRun.IsZero() && now.Sub(u.lastRun) < u.interval {
return nil
}
if err := fn(); err != nil {
return err
}
u.lastRun = now
return nil
}

View File

@@ -0,0 +1,50 @@
package osimage
import (
"context"
"fmt"
"io"
"net/http"
"time"
"github.com/klauspost/compress/zstd"
)
func OpenDecompressedHTTPStream(ctx context.Context, url string, timeout time.Duration) (io.Reader, func() error, error) {
if url == "" {
return nil, nil, fmt.Errorf("url is required")
}
if timeout <= 0 {
timeout = 30 * time.Minute
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, nil, fmt.Errorf("build request: %w", err)
}
client := &http.Client{Timeout: timeout}
resp, err := client.Do(req)
if err != nil {
return nil, nil, fmt.Errorf("http get %q: %w", url, err)
}
if resp.StatusCode != http.StatusOK {
resp.Body.Close()
return nil, nil, fmt.Errorf("unexpected status: %s", resp.Status)
}
dec, err := zstd.NewReader(resp.Body)
if err != nil {
resp.Body.Close()
return nil, nil, fmt.Errorf("create zstd decoder: %w", err)
}
closeFn := func() error {
dec.Close()
return resp.Body.Close()
}
return dec, closeFn, nil
}

View File

@@ -0,0 +1,29 @@
package osimage
import "time"
type ApplyOptions struct {
URL string
TargetPath string
ExpectedRawSHA256 string
ExpectedRawSize int64
HTTPTimeout time.Duration
BufferSize int
Progress ProgressFunc
}
type Progress struct {
Stage string
BytesComplete int64
BytesTotal int64
}
type ProgressFunc func(Progress)
type ApplyResult struct {
BytesWritten int64
VerifiedSHA256 string
VerificationOK bool
}

View File

@@ -0,0 +1,104 @@
package osimage
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"os"
"strings"
)
func VerifyTargetSHA256(ctx context.Context, targetPath string, expectedSize int64,
bufferSize int, progress ProgressFunc) (string, error) {
if targetPath == "" {
return "", fmt.Errorf("target path is required")
}
if expectedSize <= 0 {
return "", fmt.Errorf("expected raw size is required for verification")
}
if bufferSize <= 0 {
bufferSize = 4 * 1024 * 1024
}
f, err := os.Open(targetPath)
if err != nil {
return "", fmt.Errorf("open target for verify: %w", err)
}
defer f.Close()
h := sha256.New()
buf := make([]byte, bufferSize)
var readTotal int64
limited := io.LimitReader(f, expectedSize)
for {
select {
case <-ctx.Done():
return "", ctx.Err()
default:
}
n, err := limited.Read(buf)
if n > 0 {
if _, werr := h.Write(buf[:n]); werr != nil {
return "", fmt.Errorf("hash target: %w", werr)
}
readTotal += int64(n)
if progress != nil {
progress(Progress{
Stage: "verify",
BytesComplete: readTotal,
BytesTotal: expectedSize,
})
}
}
if err == io.EOF {
break
}
if err != nil {
return "", fmt.Errorf("read target: %w", err)
}
}
if readTotal != expectedSize {
return "", fmt.Errorf("verify size mismatch: got %d want %d", readTotal, expectedSize)
}
return hex.EncodeToString(h.Sum(nil)), nil
}
func ValidateApplyOptions(opts ApplyOptions) error {
if opts.URL == "" {
return fmt.Errorf("url is required")
}
if opts.TargetPath == "" {
return fmt.Errorf("target path is required")
}
if opts.ExpectedRawSHA256 == "" {
return fmt.Errorf("expected raw sha256 is required")
}
if opts.ExpectedRawSize <= 0 {
return fmt.Errorf("expected raw size must be > 0")
}
return nil
}
func VerifySHA256(got, expected string) error {
expected = NormalizeSHA256(expected)
if expected == "" {
return nil
}
got = NormalizeSHA256(got)
if got != expected {
return fmt.Errorf("sha256 mismatch: got %s want %s", got, expected)
}
return nil
}
func NormalizeSHA256(s string) string {
return strings.ToLower(strings.TrimSpace(s))
}

View File

@@ -0,0 +1,35 @@
//go:build !dev
package osimage
import (
"fmt"
"os"
"strings"
)
func CheckTargetSafe(targetPath string, expectedRawSize int64) error {
if !strings.HasPrefix(targetPath, "/dev/") {
return fmt.Errorf("target must be a device path under /dev")
}
st, err := os.Stat(targetPath)
if err != nil {
return fmt.Errorf("stat target: %w", err)
}
mode := st.Mode()
if mode&os.ModeDevice == 0 {
return fmt.Errorf("target is not a device")
}
// TODO: Add stronger checks
// - EnsureNotMounted(targetPath)
// - EnsureNotCurrentRoot(targetPath)
// - EnsurePartitionNotWholeDisk(targetPath)
// - EnsureCapacity(targetPath, expectedRawSize)
_ = expectedRawSize
return nil
}

View File

@@ -0,0 +1,7 @@
//go:build dev
package osimage
func CheckTargetSafe(targetPath string, expectedRawSize int64) error {
return nil
}

View File

@@ -0,0 +1,82 @@
package osimage
import (
"context"
"fmt"
"io"
"os"
)
func WriteStreamToTarget(ctx context.Context,
src io.Reader,
targetPath string,
expectedSize int64, bufferSize int,
progress ProgressFunc,
) (int64, error) {
if targetPath == "" {
return 0, fmt.Errorf("target path is required")
}
if bufferSize <= 0 {
bufferSize = 4 * 1024 * 1024
}
f, err := os.OpenFile(targetPath, os.O_WRONLY, 0)
if err != nil {
return 0, fmt.Errorf("open target: %w", err)
}
defer f.Close()
written, err := copyWithProgressBuffer(ctx, f, src, expectedSize, "flash", progress, make([]byte, bufferSize))
if err != nil {
return written, err
}
if expectedSize > 0 && written != expectedSize {
return written, fmt.Errorf("written size mismatch: got %d want %d", written, expectedSize)
}
if err := f.Sync(); err != nil {
return written, fmt.Errorf("sync target: %w", err)
}
return written, nil
}
func copyWithProgressBuffer(ctx context.Context, dst io.Writer, src io.Reader, total int64, stage string, progress ProgressFunc, buf []byte) (int64, error) {
var written int64
for {
select {
case <-ctx.Done():
return written, ctx.Err()
default:
}
nr, er := src.Read(buf)
if nr > 0 {
nw, ew := dst.Write(buf[:nr])
if nw > 0 {
written += int64(nw)
if progress != nil {
progress(Progress{
Stage: stage,
BytesComplete: written,
BytesTotal: total,
})
}
}
if ew != nil {
return written, ew
}
if nw != nr {
return written, io.ErrShortWrite
}
}
if er != nil {
if er == io.EOF {
return written, nil
}
return written, fmt.Errorf("copy %s: %w", stage, er)
}
}
}